@ -1,6 +1,7 @@
# pragma once
# include <algorithm>
# include <atomic>
# include <type_traits>
# include <unordered_map>
# include <unordered_set>
@ -197,6 +198,8 @@ Add a mark to which output is temporary is helpful for future optimization.
class OpRegistry {
using OpCreator = std : : function < OperatorBase * ( ) > ;
using VarIndexMap = std : : unordered_map < std : : string , int > ;
using VarNameList = std : : vector < std : : string > ;
public :
template < typename OpType , typename ProtoMakerType >
@ -211,24 +214,64 @@ class OpRegistry {
op_proto . IsInitialized ( ) ,
" Fail to initialize %s's OpProto, because %s is not initialized " ,
op_type , op_proto . InitializationErrorString ( ) ) ;
VarIndexMaps ( ) [ op_type ] . reset ( new VarIndexMap ( ) ) ;
auto & varmap = * VarIndexMaps ( ) [ op_type ] ;
int idx = 0 ;
for ( auto & var : op_proto . inputs ( ) ) {
varmap [ var . name ( ) ] = idx + + ;
}
idx = 0 ;
for ( auto & var : op_proto . outputs ( ) ) {
varmap [ var . name ( ) ] = idx + + ;
}
}
static OperatorPtr CreateOp ( const std : : string & type ,
const VarNameList & inputs ,
const VarNameList & outputs ,
const AttributeMap & attrs ) {
auto op_create_it = creators ( ) . find ( type ) ;
PADDLE_ENFORCE ( op_create_it ! = creators ( ) . end ( ) ,
" Operator %s cannot be found " , type ) ;
auto op = op_create_it - > second ( ) ;
op - > type_ = type ;
op - > inputs_ = inputs ;
op - > outputs_ = outputs ;
op - > attrs_ = attrs ;
op_checkers ( ) . at ( type ) . Check ( op - > attrs_ ) ;
GenerateTempVariableName ( op ) ;
{
auto var_index_it = VarIndexMaps ( ) . find ( type ) ;
if ( var_index_it ! = VarIndexMaps ( ) . end ( ) ) {
op - > in_out_idxs_ = var_index_it - > second ;
}
}
op - > Init ( ) ;
return OperatorPtr ( op ) ;
}
static OperatorPtr CreateOp ( const OpDesc & op_desc ) {
std : : string op_type = op_desc . type ( ) ;
OperatorPtr op ( creators ( ) . at ( op_type ) ( ) ) ;
op - > type_ = op_desc . type ( ) ;
op - > inputs_ . reserve ( ( size_t ) op_desc . inputs_size ( ) ) ;
std : : vector < std : : string > inputs ;
inputs . reserve ( ( size_t ) op_desc . inputs_size ( ) ) ;
std : : copy ( op_desc . inputs ( ) . begin ( ) , op_desc . inputs ( ) . end ( ) ,
std : : back_inserter ( op - > inputs_ ) ) ;
op - > outputs_ . reserve ( ( size_t ) op_desc . outputs_size ( ) ) ;
std : : back_inserter ( inputs ) ) ;
std : : vector < std : : string > outputs ;
outputs . reserve ( ( size_t ) op_desc . outputs_size ( ) ) ;
std : : copy ( op_desc . outputs ( ) . begin ( ) , op_desc . outputs ( ) . end ( ) ,
std : : back_inserter ( op - > outputs_ ) ) ;
std : : back_inserter ( outputs ) ) ;
AttributeMap attrs ;
for ( auto & attr : op_desc . attrs ( ) ) {
op - > attrs_ [ attr . name ( ) ] = AttrTypeHelper : : GetAttrValue ( attr ) ;
attrs[ attr . name ( ) ] = AttrTypeHelper : : GetAttrValue ( attr ) ;
}
op_checkers ( ) . at ( op_type ) . Check ( op - > attrs_ ) ;
op - > Init ( ) ;
return op ;
return CreateOp ( op_desc . type ( ) , inputs , outputs , attrs ) ;
}
static std : : unordered_map < std : : string , OpProto > & protos ( ) {
@ -237,6 +280,23 @@ class OpRegistry {
} ;
private :
static std : : unordered_map < std : : string , std : : shared_ptr < VarIndexMap > > &
VarIndexMaps ( ) {
static std : : unordered_map < std : : string , std : : shared_ptr < VarIndexMap > > maps_ ;
return maps_ ;
}
static void GenerateTempVariableName ( OperatorBase * op ) {
static std : : atomic < size_t > gUniqId ( 0UL ) ;
for ( auto & outname : op - > outputs_ ) {
if ( outname = = OperatorBase : : TMP_VAR_NAME ( ) ) {
outname + = op - > type_ ;
outname + = " @ " ;
outname + = std : : to_string ( gUniqId . fetch_add ( 1 ) ) ;
}
}
}
static std : : unordered_map < std : : string , OpCreator > & creators ( ) {
static std : : unordered_map < std : : string , OpCreator > creators_ ;
return creators_ ;
@ -278,7 +338,7 @@ class OpRegisterHelper {
/**
* Macro to Register OperatorKernel .
*/
# define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType) \
# define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE ( \
__reg_op_kernel_ # # type # # _ # # DEVICE_TYPE # # __ , \
" REGISTER_OP_KERNEL must be in global namespace " ) ; \
@ -287,17 +347,19 @@ class OpRegisterHelper {
: : paddle : : framework : : OperatorWithKernel : : OpKernelKey key ; \
key . place_ = PlaceType ( ) ; \
: : paddle : : framework : : OperatorWithKernel : : AllOpKernels ( ) [ # type ] [ key ] \
. reset ( new KernelType( ) ) ; \
. reset ( new __VA_ARGS__( ) ) ; \
} \
} ; \
static __op_kernel_register__ # # type # # __ __reg_kernel_ # # type # # __ ; \
int __op_kernel_register_ # # type # # _handle_ # # DEVICE_TYPE # # __ ( ) { return 0 ; }
# define REGISTER_OP_GPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL ( type , GPU , : : paddle : : platform : : GPUPlace , KernelType )
// (type, KernelType)
# define REGISTER_OP_GPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL ( type , GPU , : : paddle : : platform : : GPUPlace , __VA_ARGS__ )
# define REGISTER_OP_CPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL ( type , CPU , : : paddle : : platform : : CPUPlace , KernelType )
// (type, KernelType)
# define REGISTER_OP_CPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL ( type , CPU , : : paddle : : platform : : CPUPlace , __VA_ARGS__ )
/**
* Macro to mark what Operator and Kernel we will use and tell the compiler to