@ -18,6 +18,7 @@
# include <deque>
# include <list>
# include <memory>
# include <unordered_set>
# include "paddle/framework/block_desc.h"
# include "paddle/framework/op_registry.h"
@ -285,6 +286,15 @@ static bool AllGradInSet(const std::vector<std::string>& names,
return true ;
}
static std : : string FwdName ( const std : : string & grad_name ) {
auto pos = grad_name . find ( " @GRAD " ) ;
if ( pos = = std : : string : : npos ) {
return " " ;
} else {
return grad_name . substr ( 0 , pos ) ;
}
}
static void CreateGradVarInBlock (
size_t grad_op_start_index ,
const std : : unordered_map < std : : string , std : : string > & param_name_map ,
@ -294,6 +304,7 @@ static void CreateGradVarInBlock(
for ( size_t op_index = grad_op_start_index ; op_index < ops . size ( ) ;
+ + op_index ) {
bool need_infer_shape = false ;
std : : unordered_set < std : : string > new_vars ;
ForEachVarName ( ops [ op_index ] - > Outputs ( ) ,
[ & ] ( const std : : string & grad_var_name ) {
if ( block_desc - > HasVar ( grad_var_name ) ) {
@ -301,8 +312,7 @@ static void CreateGradVarInBlock(
}
need_infer_shape = true ;
auto var = block_desc - > Var ( grad_var_name ) ;
// FIXME(qiao) infer the datatype
var - > SetDataType ( framework : : DataType : : FP32 ) ;
new_vars . insert ( var - > Name ( ) ) ;
auto it = param_name_map . find ( grad_var_name ) ;
if ( it = = param_name_map . end ( ) ) {
return false ;
@ -316,6 +326,21 @@ static void CreateGradVarInBlock(
} ) ;
if ( need_infer_shape ) {
ops [ op_index ] - > InferVarType ( block_desc ) ;
for ( auto & arg : ops [ op_index ] - > OutputArgumentNames ( ) ) {
if ( new_vars . find ( arg ) = = new_vars . end ( ) ) {
continue ;
}
auto pname = FwdName ( arg ) ;
auto * param = block_desc - > FindVar ( pname ) ;
auto * grad = block_desc - > FindVar ( arg ) ;
if ( param = = nullptr ) {
LOG ( WARNING ) < < " Cannot find forward variable of " < < arg
< < " . Set its gradient to FP32 " ;
grad - > SetDataType ( DataType : : FP32 ) ;
} else {
grad - > SetDataType ( param - > GetDataType ( ) ) ;
}
}
ops [ op_index ] - > InferShape ( * block_desc ) ;
}
}
@ -368,7 +393,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind & program_desc , int block_idx ,
std : : unordered_set < std : : string > * no_grad_vars ,
std : : unordered_map < std : : string , std : : string > * grad_to_var ) {
BlockDescBind * cur_block = program_desc . Block( block_idx ) ;
BlockDescBind * cur_block = program_desc . Mutable Block( block_idx ) ;
std : : vector < OpDescBind * > op_descs = cur_block - > AllOps ( ) ;
std : : unordered_map < std : : string , std : : vector < size_t > > dup_out_ops ;
size_t grad_desc_idx = 0 ;
@ -443,7 +468,7 @@ ParamGradInfoMap AppendBackward(
}
const int root_block_idx = 0 ;
auto root_block = program_desc . Block( root_block_idx ) ;
auto root_block = program_desc . Mutable Block( root_block_idx ) ;
// insert fill one op for target
// TODO(qiao) add some check to the target.
@ -492,7 +517,7 @@ ParamGradInfoMap AppendBackward(
CreateGradVarInBlock ( forward_op_num , grad_to_var , root_block , & retv ) ;
for ( size_t block_index = forward_block_num ;
block_index < program_desc . Size ( ) ; + + block_index ) {
CreateGradVarInBlock ( 0 , grad_to_var , program_desc . Block( block_index ) ,
CreateGradVarInBlock ( 0 , grad_to_var , program_desc . Mutable Block( block_index ) ,
& retv ) ;
}
return retv ;