@ -198,6 +198,7 @@ Add a mark to which output is temporary is helpful for future optimization.
class OpRegistry {
using OpCreator = std : : function < OperatorBase * ( ) > ;
using VarIndexMap = std : : unordered_map < std : : string , int > ;
public :
template < typename OpType , typename ProtoMakerType >
@ -212,6 +213,17 @@ class OpRegistry {
op_proto . IsInitialized ( ) ,
" Fail to initialize %s's OpProto, because %s is not initialized " ,
op_type , op_proto . InitializationErrorString ( ) ) ;
VarIndexMaps ( ) [ op_type ] . reset ( new VarIndexMap ( ) ) ;
auto & varmap = * VarIndexMaps ( ) [ op_type ] ;
int idx = 0 ;
for ( auto & var : op_proto . inputs ( ) ) {
varmap [ var . name ( ) ] = idx + + ;
}
idx = 0 ;
for ( auto & var : op_proto . outputs ( ) ) {
varmap [ var . name ( ) ] = idx + + ;
}
}
static OperatorPtr CreateOp ( const OpDesc & op_desc ) {
@ -220,7 +232,6 @@ class OpRegistry {
OperatorPtr op ( creators ( ) . at ( op_type ) ( ) ) ;
//! Fill the op's data members. The constructor is not used for this because
//! it would add noise for Op developers.
const OpProto & op_proto = protos ( ) . at ( op_type ) ;
op - > type_ = op_desc . type ( ) ;
// set op's inputs_ from desc.
op - > inputs_ . reserve ( ( size_t ) op_desc . inputs_size ( ) ) ;
@ -240,25 +251,31 @@ class OpRegistry {
//! Convert each temporary variable name to a unique variable name.
GenerateTempVariableName ( op . get ( ) ) ;
// set argument offsets stored in op.
CreateInOutOffsetMap ( op , op_proto ) ;
//! set argument offsets stored in op.
{
auto var_index_it = VarIndexMaps ( ) . find ( op_type ) ;
if ( var_index_it ! = VarIndexMaps ( ) . end ( ) ) {
op - > in_out_idxs_ = var_index_it - > second ;
}
}
//! Call the op's custom Init, which is intended for complex Ops. For a simple
//! Op, the Init method does nothing.
op - > Init ( ) ;
return op ;
}
// Prepares the argument-offset lookup of `op` (its in_out_idxs_) from `proto`
// by forwarding to the operator's own CreateInOutOffsetMap.
static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) {
  auto& target_op = *op;
  target_op.CreateInOutOffsetMap(proto);
}
static std : : unordered_map < std : : string , OpProto > & protos ( ) {
static std : : unordered_map < std : : string , OpProto > protos_ ;
return protos_ ;
} ;
private :
static std : : unordered_map < std : : string , std : : shared_ptr < VarIndexMap > > &
VarIndexMaps ( ) {
static std : : unordered_map < std : : string , std : : shared_ptr < VarIndexMap > > maps_ ;
return maps_ ;
}
static void GenerateTempVariableName ( OperatorBase * op ) {
static std : : atomic < size_t > gUniqId ( 0UL ) ;
for ( auto & outname : op - > outputs_ ) {
@ -311,7 +328,7 @@ class OpRegisterHelper {
/**
* Macro to Register OperatorKernel .
*/
# define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType) \
# define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE ( \
__reg_op_kernel_ # # type # # _ # # DEVICE_TYPE # # __ , \
" REGISTER_OP_KERNEL must be in global namespace " ) ; \
@ -320,17 +337,19 @@ class OpRegisterHelper {
: : paddle : : framework : : OperatorWithKernel : : OpKernelKey key ; \
key . place_ = PlaceType ( ) ; \
: : paddle : : framework : : OperatorWithKernel : : AllOpKernels ( ) [ # type ] [ key ] \
. reset ( new KernelType( ) ) ; \
. reset ( new __VA_ARGS__( ) ) ; \
} \
} ; \
static __op_kernel_register__ # # type # # __ __reg_kernel_ # # type # # __ ; \
int __op_kernel_register_ # # type # # _handle_ # # DEVICE_TYPE # # __ ( ) { return 0 ; }
// Registers KernelType as the GPU kernel of operator `type`.
// Thin wrapper: forwards to REGISTER_OP_KERNEL with the device tag GPU and
// the place type ::paddle::platform::GPUPlace.
# define REGISTER_OP_GPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL ( type , GPU , : : paddle : : platform : : GPUPlace , KernelType )
// Usage: REGISTER_OP_GPU_KERNEL(type, KernelType)
// Registers a GPU kernel for operator `type`. Thin wrapper: forwards to
// REGISTER_OP_KERNEL with the device tag GPU and the place type
// ::paddle::platform::GPUPlace. Everything after `type` is passed through
// unchanged via __VA_ARGS__ as the kernel type.
// NOTE(review): presumably variadic so that a kernel type containing commas
// (e.g. a template with several arguments) can be passed unparenthesized —
// confirm.
# define REGISTER_OP_GPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL ( type , GPU , : : paddle : : platform : : GPUPlace , __VA_ARGS__ )
// Registers KernelType as the CPU kernel of operator `type`.
// Thin wrapper: forwards to REGISTER_OP_KERNEL with the device tag CPU and
// the place type ::paddle::platform::CPUPlace.
# define REGISTER_OP_CPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL ( type , CPU , : : paddle : : platform : : CPUPlace , KernelType )
// Usage: REGISTER_OP_CPU_KERNEL(type, KernelType)
// Registers a CPU kernel for operator `type`. Thin wrapper: forwards to
// REGISTER_OP_KERNEL with the device tag CPU and the place type
// ::paddle::platform::CPUPlace. Everything after `type` is passed through
// unchanged via __VA_ARGS__ as the kernel type.
// NOTE(review): presumably variadic so that a kernel type containing commas
// (e.g. a template with several arguments) can be passed unparenthesized —
// confirm.
# define REGISTER_OP_CPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL ( type , CPU , : : paddle : : platform : : CPUPlace , __VA_ARGS__ )
/**
* Macro to mark what Operator and Kernel we will use and tell the compiler to