@ -270,6 +270,19 @@ static bool AllGradInSet(const std::vector<std::string>& names,
return false ;
return false ;
}
}
}
}
if ( VLOG_IS_ON ( 10 ) ) {
std : : ostringstream sout ;
sout < < " All input { " ;
for ( auto & name : names ) {
sout < < name < < " , " ;
}
sout < < " } is in { " ;
for ( auto & name : set ) {
sout < < name < < " , " ;
}
sout < < " } " ;
VLOG ( 10 ) < < sout . str ( ) ;
}
return true ;
return true ;
}
}
@ -290,14 +303,12 @@ static void CreateGradVarInBlock(
auto ops = block_desc - > AllOps ( ) ;
auto ops = block_desc - > AllOps ( ) ;
for ( size_t op_index = grad_op_start_index ; op_index < ops . size ( ) ;
for ( size_t op_index = grad_op_start_index ; op_index < ops . size ( ) ;
+ + op_index ) {
+ + op_index ) {
bool need_infer_shape = false ;
std : : unordered_set < std : : string > new_vars ;
std : : unordered_set < std : : string > new_vars ;
ForEachVarName ( ops [ op_index ] - > Outputs ( ) ,
ForEachVarName ( ops [ op_index ] - > Outputs ( ) ,
[ & ] ( const std : : string & grad_var_name ) {
[ & ] ( const std : : string & grad_var_name ) {
if ( block_desc - > HasVar ( grad_var_name ) ) {
if ( block_desc - > HasVar ( grad_var_name ) ) {
return false ;
return false ;
}
}
need_infer_shape = true ;
auto var = block_desc - > Var ( grad_var_name ) ;
auto var = block_desc - > Var ( grad_var_name ) ;
new_vars . insert ( var - > Name ( ) ) ;
new_vars . insert ( var - > Name ( ) ) ;
auto it = param_name_map . find ( grad_var_name ) ;
auto it = param_name_map . find ( grad_var_name ) ;
@ -311,7 +322,6 @@ static void CreateGradVarInBlock(
grad_record . op_idx_ = static_cast < int > ( op_index ) ;
grad_record . op_idx_ = static_cast < int > ( op_index ) ;
return false ; /* not break */
return false ; /* not break */
} ) ;
} ) ;
if ( need_infer_shape ) {
ops [ op_index ] - > InferVarType ( block_desc ) ;
ops [ op_index ] - > InferVarType ( block_desc ) ;
for ( auto & arg : ops [ op_index ] - > OutputArgumentNames ( ) ) {
for ( auto & arg : ops [ op_index ] - > OutputArgumentNames ( ) ) {
if ( new_vars . find ( arg ) = = new_vars . end ( ) ) {
if ( new_vars . find ( arg ) = = new_vars . end ( ) ) {
@ -329,7 +339,6 @@ static void CreateGradVarInBlock(
ops [ op_index ] - > InferShape ( * block_desc ) ;
ops [ op_index ] - > InferShape ( * block_desc ) ;
}
}
}
}
}
std : : vector < std : : unique_ptr < OpDescBind > > MakeOpGrad (
std : : vector < std : : unique_ptr < OpDescBind > > MakeOpGrad (
const OpDescBind * op_desc , std : : unordered_set < std : : string > * no_grad_vars ,
const OpDescBind * op_desc , std : : unordered_set < std : : string > * no_grad_vars ,
@ -387,6 +396,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind & program_desc , int block_idx ,
ProgramDescBind & program_desc , int block_idx ,
std : : unordered_set < std : : string > * no_grad_vars ,
std : : unordered_set < std : : string > * no_grad_vars ,
std : : unordered_map < std : : string , std : : string > * grad_to_var ) {
std : : unordered_map < std : : string , std : : string > * grad_to_var ) {
VLOG ( 5 ) < < " MakeBlockBackward " ;
BlockDescBind * cur_block = program_desc . MutableBlock ( block_idx ) ;
BlockDescBind * cur_block = program_desc . MutableBlock ( block_idx ) ;
std : : vector < OpDescBind * > op_descs = cur_block - > AllOps ( ) ;
std : : vector < OpDescBind * > op_descs = cur_block - > AllOps ( ) ;
std : : unordered_map < std : : string , std : : vector < size_t > > dup_out_ops ;
std : : unordered_map < std : : string , std : : vector < size_t > > dup_out_ops ;
@ -394,9 +404,10 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std : : vector < std : : unique_ptr < OpDescBind > > backward_descs ;
std : : vector < std : : unique_ptr < OpDescBind > > backward_descs ;
for ( auto it = op_descs . rbegin ( ) ; it ! = op_descs . rend ( ) ; + + it ) {
for ( auto it = op_descs . rbegin ( ) ; it ! = op_descs . rend ( ) ; + + it ) {
VLOG ( 5 ) < < " Making backward " < < ( * it ) - > Type ( ) < < " op " ;
std : : vector < std : : unique_ptr < OpDescBind > > op_grads ;
std : : vector < std : : unique_ptr < OpDescBind > > op_grads ;
if ( ( * it ) - > Type ( ) = = " recurrent " ) {
if ( ( * it ) - > Type ( ) = = " recurrent " | | ( * it ) - > Type ( ) = = " while " ) {
int step_block_idx = ( * it ) - > GetBlockAttr ( " step_block " ) ;
int step_block_idx = ( * it ) - > GetBlockAttr ( " step_block " ) ;
BlockDescBind * backward_block = CreateStepBlock (
BlockDescBind * backward_block = CreateStepBlock (
program_desc , no_grad_vars , grad_to_var , step_block_idx ) ;
program_desc , no_grad_vars , grad_to_var , step_block_idx ) ;
@ -410,6 +421,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
op_grads = MakeOpGrad ( * it , no_grad_vars , grad_to_var ) ;
op_grads = MakeOpGrad ( * it , no_grad_vars , grad_to_var ) ;
}
}
if ( VLOG_IS_ON ( 10 ) ) {
std : : ostringstream sout ;
sout < < " Made " ;
for ( auto & op_grad : op_grads ) {
sout < < op_grad - > Type ( ) < < " " ;
}
VLOG ( 10 ) < < sout . str ( ) ;
}
for ( const auto & desc : op_grads ) {
for ( const auto & desc : op_grads ) {
for ( const std : : string & out_name : desc - > OutputArgumentNames ( ) ) {
for ( const std : : string & out_name : desc - > OutputArgumentNames ( ) ) {
if ( out_name . find ( " @GRAD " ) = = std : : string : : npos ) {
if ( out_name . find ( " @GRAD " ) = = std : : string : : npos ) {
@ -425,6 +445,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
op_grads . begin ( ) , op_grads . end ( ) , std : : back_inserter ( backward_descs ) ,
op_grads . begin ( ) , op_grads . end ( ) , std : : back_inserter ( backward_descs ) ,
[ ] ( std : : unique_ptr < OpDescBind > & ptr ) { return std : : move ( ptr ) ; } ) ;
[ ] ( std : : unique_ptr < OpDescBind > & ptr ) { return std : : move ( ptr ) ; } ) ;
}
}
VLOG ( 5 ) < < " Appending Sums " ;
// Check whether some variables are written more than once
// Check whether some variables are written more than once
std : : list < std : : pair < size_t , std : : unique_ptr < OpDescBind > > > pending_sum_ops ;
std : : list < std : : pair < size_t , std : : unique_ptr < OpDescBind > > > pending_sum_ops ;
for ( const auto & dup : dup_out_ops ) {
for ( const auto & dup : dup_out_ops ) {
@ -432,16 +454,22 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
const std : : vector < size_t > dup_op = dup . second ;
const std : : vector < size_t > dup_op = dup . second ;
if ( out_name ! = kEmptyVarName & & dup_op . size ( ) > 1 ) {
if ( out_name ! = kEmptyVarName & & dup_op . size ( ) > 1 ) {
std : : vector < std : : string > sum_op_inputs ;
std : : vector < std : : string > sum_op_inputs ;
std : : string next_g_name = out_name ;
for ( size_t i = 0 ; i < dup_op . size ( ) ; + + i ) {
for ( size_t i = 0 ; i < dup_op . size ( ) ; + + i ) {
VLOG ( 10 ) < < backward_descs [ dup_op [ i ] ] - > Type ( ) < < " has " < < out_name
< < " duplicated " ;
std : : string new_name = out_name + " @RENAME@ " + std : : to_string ( i ) ;
std : : string new_name = out_name + " @RENAME@ " + std : : to_string ( i ) ;
backward_descs [ dup_op [ i ] ] - > Rename ( out_name , new_name ) ;
backward_descs [ dup_op [ i ] ] - > RenameOutput ( out_name , new_name ) ;
backward_descs [ dup_op [ i ] ] - > RenameInput ( out_name , next_g_name ) ;
sum_op_inputs . emplace_back ( new_name ) ;
sum_op_inputs . emplace_back ( new_name ) ;
next_g_name = sum_op_inputs . back ( ) ;
}
}
std : : unique_ptr < OpDescBind > sum_op ( new OpDescBind (
std : : unique_ptr < OpDescBind > sum_op ( new OpDescBind (
" sum " , { { " X " , sum_op_inputs } } , { { " Out " , { out_name } } } , { } ) ) ;
" sum " , { { " X " , sum_op_inputs } } , { { " Out " , { out_name } } } , { } ) ) ;
pending_sum_ops . push_back ( { dup_op . back ( ) , std : : move ( sum_op ) } ) ;
pending_sum_ops . push_back ( { dup_op . back ( ) , std : : move ( sum_op ) } ) ;
}
}
}
}
pending_sum_ops . sort (
pending_sum_ops . sort (
[ ] ( const std : : pair < size_t , std : : unique_ptr < OpDescBind > > & a ,
[ ] ( const std : : pair < size_t , std : : unique_ptr < OpDescBind > > & a ,
const std : : pair < size_t , std : : unique_ptr < OpDescBind > > & b ) {
const std : : pair < size_t , std : : unique_ptr < OpDescBind > > & b ) {
@ -452,6 +480,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std : : move ( p . second ) ) ;
std : : move ( p . second ) ) ;
}
}
VLOG ( 5 ) < < " MakeBlockBackward Finished " ;
return backward_descs ;
return backward_descs ;
}
}