@ -2,6 +2,8 @@
# include <algorithm>
# include <type_traits>
# include <unordered_map>
# include <unordered_set>
# include "paddle/framework/attr_checker.h"
# include "paddle/framework/op_desc.pb.h"
# include "paddle/framework/op_proto.pb.h"
@ -59,25 +61,52 @@ class OpProtoAndCheckerMaker {
OpProtoAndCheckerMaker ( OpProto * proto , OpAttrChecker * op_checker )
: proto_ ( proto ) , op_checker_ ( op_checker ) { }
~ OpProtoAndCheckerMaker ( ) { CheckNoDuplicatedAttrs ( ) ; }
protected :
void AddInput ( const std : : string & name , const std : : string & comment ) {
void AddInput ( const std : : string & name , const std : : string & comment ,
bool multiple = false ) {
auto input = proto_ - > mutable_inputs ( ) - > Add ( ) ;
* input - > mutable_name ( ) = name ;
* input - > mutable_comment ( ) = comment ;
input - > set_multiple ( multiple ) ;
if ( multiple ) {
SetHasMultipleInput ( ) ;
}
}
void AddInputs ( const std : : string & name , const std : : string & comment ) {
AddInput ( name , comment , true ) ;
}
void AddOutput ( const std : : string & name , const std : : string & comment ) {
void AddOutput ( const std : : string & name , const std : : string & comment ,
bool temporary = false , bool multiple = false ) {
auto output = proto_ - > mutable_outputs ( ) - > Add ( ) ;
* output - > mutable_name ( ) = name ;
* output - > mutable_comment ( ) = comment ;
output - > set_multiple ( multiple ) ;
if ( multiple ) {
SetHasMultipleOutput ( ) ;
}
output - > set_temporary ( temporary ) ;
if ( temporary ) {
SetHasTemporaryOutput ( ) ;
}
}
void AddOutputs ( const std : : string & name , const std : : string & comment ,
bool temporary = false ) {
AddOutput ( name , comment , temporary , true ) ;
}
template < typename T >
TypedAttrChecker < T > & AddAttr ( const std : : string & name ,
const std : : string & comment ) {
const std : : string & comment ,
bool generated = false ) {
auto attr = proto_ - > mutable_attrs ( ) - > Add ( ) ;
* attr - > mutable_name ( ) = name ;
* attr - > mutable_comment ( ) = comment ;
attr - > set_generated ( generated ) ;
AttrTypeHelper : : SetAttrType < T > ( attr ) ;
return op_checker_ - > AddAttrChecker < T > ( name ) ;
}
@ -86,8 +115,70 @@ class OpProtoAndCheckerMaker {
* ( proto_ - > mutable_comment ( ) ) = comment ;
}
private :
void SetHasMultiple ( const std : : string & in_out , bool * flag ) {
if ( ! * flag ) {
AddAttr < std : : vector < int > > ( in_out + " _format " ,
" The multiple index of " + in_out +
" \n "
R " DOC(
This attribute is used by Paddle core framework . Paddle ' s Op support each input
or output could be a list of variable . This attribute is used to show how that
list organized .
e . g .
input = [ " a " , " b " , " c " , " d " , " e " , " f " ]
input_format = [ 0 , 4 , 5 , 6 ]
means
The number of all input variables this op is six , and they are segmented into
three inputs .
The first input is input [ 0 : 4 ] , second is input [ 4 : 5 ] , third is input [ 5 : 6 ] .
) DOC " ,
/*generated*/ true ) ;
* flag = true ;
}
}
void SetHasMultipleInput ( ) { SetHasMultiple ( " input " , & has_multiple_input_ ) ; }
void SetHasMultipleOutput ( ) {
SetHasMultiple ( " output " , & has_multiple_output_ ) ;
}
void SetHasTemporaryOutput ( ) {
if ( ! has_temporary_output_ ) {
AddAttr < std : : vector < int > > ( " temporary_index " ,
R " DOC(The temporary index of output.
Not all output of Paddle Op is used by user . For faster computation , each op
could output some its internal state to other op , other op could take that
output to make compute faster .
Add a mark to which output is temporary is helpful for future optimization .
) DOC " ,
/*generated*/ true )
. SetDefault ( std : : vector < int > ( ) ) ;
has_temporary_output_ = true ;
}
}
void CheckNoDuplicatedAttrs ( ) {
std : : unordered_set < std : : string > names ;
size_t cnt = 0 ;
for ( auto & attr : proto_ - > attrs ( ) ) {
names . insert ( attr . name ( ) ) ;
+ + cnt ;
}
PADDLE_ENFORCE ( names . size ( ) = = cnt ,
" Cannot register two attribute in same name! " ) ;
}
OpProto * proto_ ;
OpAttrChecker * op_checker_ ;
bool has_multiple_input_ { false } ;
bool has_multiple_output_ { false } ;
bool has_temporary_output_ { false } ;
} ;
class OpRegistry {