@ -228,6 +228,11 @@ class OpRegistry {
}
}
/**
 * Registers a factory for the gradient operator keyed by `op_type`.
 *
 * The factory stored in grad_creators() returns the result of `new OpType`
 * each time it is invoked. Whatever was previously stored under `op_type`
 * is overwritten.
 */
template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
  auto make_grad_op = [] { return new OpType; };
  grad_creators()[op_type] = make_grad_op;
}
static OperatorPtr CreateOp ( const std : : string & type ,
const VarNameList & inputs ,
const VarNameList & outputs ,
@ -240,6 +245,7 @@ class OpRegistry {
op - > type_ = type ;
op - > inputs_ = inputs ;
op - > outputs_ = outputs ;
op - > attrs_ = attrs ;
op_checkers ( ) . at ( type ) . Check ( op - > attrs_ ) ;
@ -256,11 +262,6 @@ class OpRegistry {
return OperatorPtr ( op ) ;
}
/**
 * Stores, under the key `op_type`, a factory that produces a gradient
 * operator by evaluating `new OpType`. An existing entry for the same key
 * is replaced.
 */
template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
  auto& registry = grad_creators();
  registry[op_type] = [] { return new OpType; };
}
static OperatorPtr CreateOp ( const OpDesc & op_desc ) {
std : : vector < std : : string > inputs ;
inputs . reserve ( ( size_t ) op_desc . inputs_size ( ) ) ;
@ -280,19 +281,16 @@ class OpRegistry {
return CreateOp ( op_desc . type ( ) , inputs , outputs , attrs ) ;
}
static OperatorPtr CreateGradOp ( std : : shared_ptr < OperatorBase > op ) {
OperatorPtr op_grad ( grad_creators ( ) . at ( op - > type_ ) ( ) ) ;
op_grad - > type_ = op - > type_ ;
op_grad - > inputs_ . reserve ( op - > inputs_ . size ( ) ) ;
for ( auto & input : op - > inputs_ ) {
op_grad - > inputs_ . emplace_back ( input ) ;
op_grad - > outputs_ . emplace_back ( input + " @grad " ) ;
}
for ( auto & output : op - > outputs_ ) {
op_grad - > inputs_ . emplace_back ( output ) ;
op_grad - > inputs_ . emplace_back ( output + " @grad " ) ;
}
return op_grad ;
/**
 * Builds the gradient operator that corresponds to the forward operator `op`.
 *
 * The gradient operator is obtained from the factory registered in
 * grad_creators() under `op->type_` (the lookup uses at(), so a missing
 * entry throws std::out_of_range). It is given the same type_ as `op`, then
 * its inputs/outputs, argument offsets and attributes are derived from `op`,
 * and finally Init() is called on it.
 */
static OperatorPtr CreateGradOp(OperatorPtr op) {
  const auto& make_grad = grad_creators().at(op->type_);
  OperatorPtr result(make_grad());
  result->type_ = op->type_;

  // Populate the freshly created operator from the forward operator; Init()
  // runs only after all three derivation steps have completed.
  AssembleGradInOut(op, result);
  GenerateGradArgOffset(op, result);
  GenerateGradAttr(op, result);
  result->Init();

  return result;
}
static std : : unordered_map < std : : string , OpProto > & protos ( ) {
@ -307,6 +305,21 @@ class OpRegistry {
return maps_ ;
}
static std : : unordered_map < std : : string , OpCreator > & creators ( ) {
static std : : unordered_map < std : : string , OpCreator > creators_ ;
return creators_ ;
}
static std : : unordered_map < std : : string , OpAttrChecker > & op_checkers ( ) {
static std : : unordered_map < std : : string , OpAttrChecker > op_checkers_ ;
return op_checkers_ ;
} ;
static std : : unordered_map < std : : string , OpCreator > & grad_creators ( ) {
static std : : unordered_map < std : : string , OpCreator > grad_creators_ ;
return grad_creators_ ;
}
static void GenerateTempVariableName ( OperatorBase * op ) {
static std : : atomic < size_t > gUniqId ( 0UL ) ;
for ( auto & outname : op - > outputs_ ) {
@ -318,19 +331,98 @@ class OpRegistry {
}
}
static std : : unordered_map < std : : string , OpCreator > & creators ( ) {
static std : : unordered_map < std : : string , OpCreator > creators_ ;
return creators_ ;
/**
 * Fills the input and output name lists of `grad_op` from those of `op`.
 *
 * After the call, the names appended to grad_op->inputs_ are, in order:
 *   1. every name in op->inputs_,
 *   2. every name in op->outputs_,
 *   3. every name in op->outputs_ with OperatorBase::GRAD_VAR_SUFFIX()
 *      appended.
 * The names appended to grad_op->outputs_ are every name in op->inputs_ with
 * the same suffix appended.
 */
static void AssembleGradInOut(OperatorPtr op, OperatorPtr grad_op) {
  const auto& fwd_inputs = op->inputs_;
  const auto& fwd_outputs = op->outputs_;
  auto& grad_inputs = grad_op->inputs_;
  auto& grad_outputs = grad_op->outputs_;

  // Reserve once up front: forward inputs + forward outputs + one gradient
  // per forward output on the input side, one gradient per forward input on
  // the output side.
  grad_inputs.reserve(fwd_inputs.size() + fwd_outputs.size() * 2);
  grad_outputs.reserve(fwd_inputs.size());

  // Forward inputs, then forward outputs, are passed through unchanged.
  grad_inputs.insert(grad_inputs.end(), fwd_inputs.begin(), fwd_inputs.end());
  grad_inputs.insert(grad_inputs.end(), fwd_outputs.begin(),
                     fwd_outputs.end());

  // Gradients of the forward outputs are consumed by the gradient operator.
  for (const std::string& out_name : fwd_outputs) {
    grad_inputs.emplace_back(out_name + OperatorBase::GRAD_VAR_SUFFIX());
  }

  // Gradients of the forward inputs are produced by the gradient operator.
  for (const std::string& in_name : fwd_inputs) {
    grad_outputs.emplace_back(in_name + OperatorBase::GRAD_VAR_SUFFIX());
  }
}
static std : : unordered_map < std : : string , OpAttrChecker > & op_checkers ( ) {
static std : : unordered_map < std : : string , OpAttrChecker > op_checkers_ ;
return op_checkers_ ;
} ;
/**
 * Builds the name -> index map of `grad_op` (stored in in_out_idxs_) from
 * the OpProto registered for `op->type_`.
 *
 * Index assignment:
 *   - proto inputs get 0, 1, 2, ... in declaration order;
 *   - proto outputs continue the same counter;
 *   - proto output names with OperatorBase::GRAD_VAR_SUFFIX() appended
 *     continue the same counter;
 *   - proto input names with the suffix appended use a counter restarted
 *     at 0.
 *
 * Throws std::out_of_range if no OpProto is registered for `op->type_`.
 */
static void GenerateGradArgOffset(OperatorPtr op, OperatorPtr grad_op) {
  // at() instead of operator[]: operator[] would silently insert an empty
  // OpProto for an unknown type and yield an empty index map.
  const OpProto& op_proto = protos().at(op->type_);

  // Owned by a unique_ptr while being filled so it is released if anything
  // below throws; ownership is handed to grad_op at the end.
  std::unique_ptr<VarIndexMap> grad_varmap(new VarIndexMap());

  int idx = 0;
  // offset of op's inputs
  for (const auto& var : op_proto.inputs()) {
    (*grad_varmap)[var.name()] = idx++;
  }
  // offset of op's outputs
  for (const auto& var : op_proto.outputs()) {
    (*grad_varmap)[var.name()] = idx++;
  }
  // offset of gradients of op's outputs
  for (const auto& var : op_proto.outputs()) {
    (*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
  }

  // Gradients of op's inputs are numbered from zero again.
  idx = 0;
  for (const auto& var : op_proto.inputs()) {
    (*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
  }

  grad_op->in_out_idxs_.reset(grad_varmap.release());
}
static std : : unordered_map < std : : string , OpCreator > & grad_creators ( ) {
static std : : unordered_map < std : : string , OpCreator > grad_creators_ ;
return grad_creators_ ;
/**
 * Derives the attributes of `grad_op` from those of `op`.
 *
 * All attributes of `op` are copied, except that " input_format " and
 * " output_format " are dropped and then regenerated when needed:
 *
 *   - grad_op's inputs_ holds op's inputs_, op's outputs_ and the gradients
 *     of op's outputs_, so an input format is written whenever `op` has
 *     either format attribute. It is the concatenation of
 *       op's input format,
 *       op's output format shifted by op->inputs_.size(),
 *       op's output format shifted by op->inputs_.size() +
 *         op->outputs_.size().
 *   - grad_op's outputs_ holds the gradients of op's inputs_, so an output
 *     format (a copy of op's input format) is written only when `op` has an
 *     input format.
 *
 * When `op` lacks one of the two formats, 0, 1, ..., k-1 is used in its
 * place, where k is the number of inputs (resp. outputs) declared in the
 * OpProto.
 *
 * Throws std::out_of_range if no OpProto is registered for `op->type_`.
 */
static void GenerateGradAttr(OperatorPtr op, OperatorPtr grad_op) {
  // at() instead of operator[]: operator[] would insert an empty OpProto for
  // an unknown type.
  const OpProto& op_proto = protos().at(op->type_);

  grad_op->attrs_ = op->attrs_;
  grad_op->attrs_.erase(" input_format ");
  grad_op->attrs_.erase(" output_format ");

  const bool has_in_format = op->attrs_.count(" input_format ") != 0;
  const bool has_out_format = op->attrs_.count(" output_format ") != 0;
  if (!has_in_format && !has_out_format) {
    return;
  }

  // Fallback format 0..n-1, used only when the attribute is absent.
  auto default_format = [](int n) {
    std::vector<int> format(n);
    std::iota(format.begin(), format.end(), 0);
    return format;
  };

  // The previous `cond ? a = x : a = y, std::iota(...)` form ran std::iota
  // unconditionally (the comma binds looser than ?:), overwriting a format
  // explicitly supplied on `op`. An explicit branch keeps it intact.
  std::vector<int> old_in_format;
  if (has_in_format) {
    old_in_format = op->GetAttr<std::vector<int>>(" input_format ");
  } else {
    old_in_format = default_format(op_proto.inputs_size());
  }

  std::vector<int> old_out_format;
  if (has_out_format) {
    old_out_format = op->GetAttr<std::vector<int>>(" output_format ");
  } else {
    old_out_format = default_format(op_proto.outputs_size());
  }

  std::vector<int> in_format;
  in_format.reserve(old_in_format.size() + old_out_format.size() * 2);
  auto append_shifted = [&in_format](const std::vector<int>& format,
                                     int base) {
    for (const int idx : format) {
      in_format.emplace_back(idx + base);
    }
  };

  int base = 0;
  // op's inputs
  append_shifted(old_in_format, base);
  base += static_cast<int>(op->inputs_.size());
  // op's outputs
  append_shifted(old_out_format, base);
  base += static_cast<int>(op->outputs_.size());
  // gradients of op's outputs: laid out like the outputs themselves, so the
  // output format (not the input format) is the one to shift here.
  append_shifted(old_out_format, base);
  grad_op->attrs_[" input_format "] = in_format;

  // grad_op's outputs_ contains gradients of op's inputs_, so grad_op's
  // output_format is necessary only when op has input_format.
  if (has_in_format) {
    grad_op->attrs_[" output_format "] = old_in_format;
  }
}
} ;
@ -370,7 +462,7 @@ class GradOpRegisterHelper {
int __op_register_ # # __op_type # # _handle__ ( ) { return 0 ; }
/**
* Macro to Register Operator.
* Macro to Register Gradient Operator.
*/
# define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE ( \