@ -3,6 +3,7 @@
# include "paddle/framework/attr_checker.h"
//#include "paddle/framework/op_base.h"
# include <algorithm>
# include "paddle/framework/op_desc.pb.h"
# include "paddle/framework/op_proto.pb.h"
@ -64,36 +65,6 @@ struct AttrTypeHelper {
}
} ;
template < >
void AttrTypeHelper : : SetAttrType < int > ( AttrProto * attr ) {
attr - > set_type ( paddle : : framework : : AttrType : : INT ) ;
}
template < >
void AttrTypeHelper : : SetAttrType < float > ( AttrProto * attr ) {
attr - > set_type ( paddle : : framework : : AttrType : : FLOAT ) ;
}
template < >
void AttrTypeHelper : : SetAttrType < std : : string > ( AttrProto * attr ) {
attr - > set_type ( paddle : : framework : : AttrType : : STRING ) ;
}
template < >
void AttrTypeHelper : : SetAttrType < std : : vector < int > > ( AttrProto * attr ) {
attr - > set_type ( paddle : : framework : : AttrType : : INTS ) ;
}
template < >
void AttrTypeHelper : : SetAttrType < std : : vector < float > > ( AttrProto * attr ) {
attr - > set_type ( paddle : : framework : : AttrType : : FLOATS ) ;
}
template < >
void AttrTypeHelper : : SetAttrType < std : : vector < std : : string > > ( AttrProto * attr ) {
attr - > set_type ( paddle : : framework : : AttrType : : STRINGS ) ;
}
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public :
@ -103,22 +74,22 @@ class OpProtoAndCheckerMaker {
protected :
void AddInput ( const std : : string & name , const std : : string & comment ) {
auto input = proto_ - > mutable_inputs ( ) - > Add ( ) ;
* ( input - > mutable_name ( ) ) = name ;
* ( input - > mutable_comment ( ) ) = comment ;
* input - > mutable_name ( ) = name ;
* input - > mutable_comment ( ) = comment ;
}
void AddOutput ( const std : : string & name , const std : : string & comment ) {
auto output = proto_ - > mutable_outputs ( ) - > Add ( ) ;
* ( output - > mutable_name ( ) ) = name ;
* ( output - > mutable_comment ( ) ) = comment ;
* output - > mutable_name ( ) = name ;
* output - > mutable_comment ( ) = comment ;
}
template < typename T >
TypedAttrChecker < T > & AddAttr ( const std : : string & name ,
const std : : string & comment ) {
auto attr = proto_ - > mutable_attrs ( ) - > Add ( ) ;
* ( attr - > mutable_name ( ) ) = name ;
* ( attr - > mutable_comment ( ) ) = comment ;
* attr - > mutable_name ( ) = name ;
* attr - > mutable_comment ( ) = comment ;
AttrTypeHelper : : SetAttrType < T > ( attr ) ;
return op_checker_ - > AddAttrChecker < T > ( name ) ;
}
@ -134,49 +105,51 @@ class OpProtoAndCheckerMaker {
} ;
class OpRegistry {
typedef std : : function < OpBase * ( ) > OpCreator ;
using OpCreator = std : : function < OpBase * ( ) > ;
public :
template < typename OpType , typename ProtoMakerType >
static void RegisterOp ( const std : : string & op_type ) {
creators_ [ op_type ] = [ ] ( ) { return new OpType ; } ;
OpProto & op_proto = protos_ [ op_type ] ;
OpAttrChecker & op_checker = op_checkers_ [ op_type ] ;
creators ( ) [ op_type ] = [ ] { return new OpType ; } ;
OpProto & op_proto = protos ( ) [ op_type ] ;
OpAttrChecker & op_checker = op_checkers ( ) [ op_type ] ;
ProtoMakerType ( & op_proto , & op_checker ) ;
PADDLE_ENFORCE ( op_proto . IsInitialized ( ) = = true ,
PADDLE_ENFORCE ( op_proto . IsInitialized ( ) ,
" Fail to initialize %s's OpProto ! " , op_type ) ;
}
static OpBase * CreateOp ( const OpDesc & op_desc ) {
std : : string op_type = op_desc . type ( ) ;
OpBase * op = ( creators_ . at ( op_type ) ) ( ) ;
( op - > inputs_ ) . resize ( op_desc . inputs_size ( ) ) ;
for ( int i = 0 ; i < op_desc . inputs_size ( ) ; + + i ) {
( op - > inputs_ ) [ i ] = op_desc . inputs ( i ) ;
}
( op - > outputs_ ) . resize ( op_desc . outputs_size ( ) ) ;
for ( int i = 0 ; i < op_desc . outputs_size ( ) ; + + i ) {
( op - > outputs_ ) [ i ] = op_desc . outputs ( i ) ;
}
for ( int i = 0 ; i < op_desc . attrs_size ( ) ; + + i ) {
const AttrDesc & ith_attr = op_desc . attrs ( i ) ;
std : : string name = ith_attr . name ( ) ;
( op - > attr_map_ ) [ name ] = AttrTypeHelper : : GetAttrValue ( ith_attr ) ;
}
const OpAttrChecker & op_checker = op_checkers_ . at ( op_type ) ;
op_checker . Check ( op - > attr_map_ ) ;
OpBase * op = creators ( ) . at ( op_type ) ( ) ;
op - > inputs_ . reserve ( ( size_t ) op_desc . inputs_size ( ) ) ;
std : : copy ( op_desc . inputs ( ) . begin ( ) , op_desc . inputs ( ) . end ( ) ,
std : : back_inserter ( op - > inputs_ ) ) ;
op - > outputs_ . reserve ( ( size_t ) op_desc . outputs_size ( ) ) ;
std : : copy ( op_desc . outputs ( ) . begin ( ) , op_desc . outputs ( ) . end ( ) ,
std : : back_inserter ( op - > outputs_ ) ) ;
for ( auto & attr : op_desc . attrs ( ) ) {
op - > attr_map_ [ attr . name ( ) ] = AttrTypeHelper : : GetAttrValue ( attr ) ;
}
op_checkers ( ) . at ( op_type ) . Check ( op - > attr_map_ ) ;
return op ;
}
private :
static std : : unordered_map < std : : string , OpCreator > & creators ( ) {
static std : : unordered_map < std : : string , OpCreator > creators_ ;
return creators_ ;
}
static std : : unordered_map < std : : string , OpProto > & protos ( ) {
static std : : unordered_map < std : : string , OpProto > protos_ ;
static std : : unordered_map < std : : string , OpAttrChecker > op_checkers_ ;
return proto s_;
} ;
std : : unordered_map < std : : string , std : : function < OpBase * ( ) > > OpRegistry : : creators_ ;
std : : unordered_map < std : : string , OpProto > OpRegistry : : protos_ ;
std : : unordered_map < std : : string , OpAttrChecker > OpRegistry : : op_checkers_ ;
static std : : unordered_map < std : : string , OpAttrChecker > & op_checkers ( ) {
static std : : unordered_map < std : : string , OpAttrChecker > op_checkers_ ;
return op_checkers_ ;
} ;
} ;
template < typename OpType , typename ProtoMakerType >
class OpRegisterHelper {
@ -194,60 +167,5 @@ class OpRegisterHelper {
const OpRegisterHelper < __op_class , __op_maker_class > \
__op_class # # Register : : reg ( # __op_type ) ;
// Demos
class CosineOp : public OpBase {
public :
virtual std : : string Run ( ) const {
std : : string msg = " CosineOp runs! scale = " +
std : : to_string ( boost : : get < float > ( attr_map_ . at ( " scale " ) ) ) ;
return msg ;
}
} ;
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public :
CosineOpProtoAndCheckerMaker ( OpProto * proto , OpAttrChecker * op_checker )
: OpProtoAndCheckerMaker ( proto , op_checker ) {
AddInput ( " input " , " input of cosine op " ) ;
AddOutput ( " output " , " output of cosine op " ) ;
AddAttr < float > ( " scale " , " scale of cosine op " )
. SetDefault ( 1.0 )
. LargerThan ( 0.0 ) ;
AddType ( " cos " ) ;
AddComment ( " This is cos op " ) ;
}
} ;
REGISTER_OP ( CosineOp , CosineOpProtoAndCheckerMaker , cos_sim )
class MyTestOp : public OpBase {
public :
virtual std : : string Run ( ) const {
std : : string msg =
" MyTestOp runs! test_attr = " +
std : : to_string ( boost : : get < int > ( attr_map_ . at ( " test_attr " ) ) ) ;
return msg ;
}
} ;
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public :
MyTestOpProtoAndCheckerMaker ( OpProto * proto , OpAttrChecker * op_checker )
: OpProtoAndCheckerMaker ( proto , op_checker ) {
AddInput ( " input " , " input of cosine op " ) ;
AddOutput ( " output " , " output of cosine op " ) ;
auto my_checker = [ ] ( int i ) {
PADDLE_ENFORCE ( i % 2 = = 0 , " 'test_attr' must be even! " ) ;
} ;
AddAttr < int > ( " test_attr " , " a simple test attribute " )
. AddCustomChecker ( my_checker ) ;
AddType ( " my_test_op " ) ;
AddComment ( " This is my_test op " ) ;
}
} ;
REGISTER_OP ( MyTestOp , MyTestOpProtoAndCheckerMaker , my_test_op )
} // namespace framework
} // namespace paddle