@ -204,68 +204,38 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
// Record Op infershape core function
using InferShapeFunc = std : : vector < std : : vector < int64_t > > ( * ) (
const std : : vector < std : : vector < int64_t > > & input_shapes ,
const std : : vector < std : : vector < std : : vector < int64_t > > > & vec_input_shapes ,
const std : : vector < boost : : any > & attrs ) ;
const std : : vector < std : : vector < std : : vector < int64_t > > > & vec_input_shapes ) ;
# define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(input_type) \
template < typename . . . Tail > \
struct InferShapeCallHelper < input_type , Tail . . . > { \
template < int in_idx , int vec_in_idx , int attr_idx , \
typename . . . PreviousArgs > \
static Return InferShape ( \
const std : : vector < std : : vector < int64_t > > & input_shapes , \
const std : : vector < std : : vector < std : : vector < int64_t > > > & \
vec_input_shapes , \
const std : : vector < boost : : any > & attrs , const PreviousArgs & . . . pargs ) { \
input_type arg = input_shapes [ in_idx ] ; \
return InferShapeCallHelper < Tail . . . > : : template InferShape < \
in_idx + 1 , vec_in_idx , attr_idx > ( input_shapes , vec_input_shapes , \
attrs , pargs . . . , arg ) ; \
} \
}
# define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(input_type) \
template < typename . . . Tail > \
struct InferShapeCallHelper < input_type , Tail . . . > { \
template < int in_idx , int vec_in_idx , int attr_idx , \
typename . . . PreviousArgs > \
static Return InferShape ( \
const std : : vector < std : : vector < int64_t > > & input_shapes , \
const std : : vector < std : : vector < std : : vector < int64_t > > > & \
vec_input_shapes , \
const std : : vector < boost : : any > & attrs , const PreviousArgs & . . . pargs ) { \
input_type arg = vec_input_shapes [ vec_in_idx ] ; \
return InferShapeCallHelper < Tail . . . > : : template InferShape < \
in_idx , vec_in_idx + 1 , attr_idx > ( input_shapes , vec_input_shapes , \
attrs , pargs . . . , arg ) ; \
} \
# define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(input_type) \
template < typename . . . Tail > \
struct InferShapeCallHelper < input_type , Tail . . . > { \
template < int in_idx , int vec_in_idx , typename . . . PreviousArgs > \
static Return InferShape ( \
const std : : vector < std : : vector < int64_t > > & input_shapes , \
const std : : vector < std : : vector < std : : vector < int64_t > > > & \
vec_input_shapes , \
const PreviousArgs & . . . pargs ) { \
input_type arg = input_shapes [ in_idx ] ; \
return InferShapeCallHelper < Tail . . . > : : template InferShape < in_idx + 1 , \
vec_in_idx > ( \
input_shapes , vec_input_shapes , pargs . . . , arg ) ; \
} \
}
# define PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(attr_type) \
template < typename . . . Tail > \
struct InferShapeCallHelper < attr_type , Tail . . . > { \
template < int in_idx , int vec_in_idx , int attr_idx , \
typename . . . PreviousArgs > \
static Return InferShape ( \
const std : : vector < std : : vector < int64_t > > & input_shapes , \
const std : : vector < std : : vector < std : : vector < int64_t > > > & \
vec_input_shapes , \
const std : : vector < boost : : any > & attrs , const PreviousArgs & . . . pargs ) { \
try { \
attr_type arg = boost : : any_cast < attr_type > ( attrs [ attr_idx ] ) ; \
return InferShapeCallHelper < Tail . . . > : : template InferShape < \
in_idx , vec_in_idx , attr_idx + 1 > ( input_shapes , vec_input_shapes , \
attrs , pargs . . . , arg ) ; \
} catch ( boost : : bad_any_cast & ) { \
PD_THROW ( \
" Attribute cast error in custom operator InferShapeFn. " \
" Expected " # attr_type \
" value. InferShapeFn's attribute list must be exactly same as " \
" Forward " \
" KernelFn's attribute list except std::vector<int64_t> " \
" attribute. " ) ; \
} \
} \
# define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(input_type) \
template < typename . . . Tail > \
struct InferShapeCallHelper < input_type , Tail . . . > { \
template < int in_idx , int vec_in_idx , typename . . . PreviousArgs > \
static Return InferShape ( \
const std : : vector < std : : vector < int64_t > > & input_shapes , \
const std : : vector < std : : vector < std : : vector < int64_t > > > & \
vec_input_shapes , \
const PreviousArgs & . . . pargs ) { \
input_type arg = vec_input_shapes [ vec_in_idx ] ; \
return InferShapeCallHelper < Tail . . . > : : template InferShape < \
in_idx , vec_in_idx + 1 > ( input_shapes , vec_input_shapes , pargs . . . , \
arg ) ; \
} \
}
template < typename F , F f >
@ -275,10 +245,10 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct InferShapeFuncImpl < Return ( * ) ( Args . . . ) , impl_fn > {
static Return InferShape (
const std : : vector < std : : vector < int64_t > > & input_shapes ,
const std : : vector < std : : vector < std : : vector < int64_t > > > & vec_input_shapes ,
const std : : vector < boost : : any > & attrs ) {
return InferShapeCallHelper < Args . . . , TypeTag < int > > : : template InferShape <
0 , 0 , 0 > ( input_shapes , vec_input_shape s, attr s) ;
const std : : vector < std : : vector < std : : vector < int64_t > > > & vec_input_shapes ) {
return InferShapeCallHelper < Args . . . , TypeTag < int > > : : template InferShape < 0 ,
0 > (
input_shapes , vec_input_shape s) ;
}
private :
@ -295,26 +265,14 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES (
std : : vector < std : : vector < int64_t > > ) ;
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR ( const bool & ) ;
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR ( const int & ) ;
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR ( const float & ) ;
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR ( const int64_t & ) ;
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR ( const std : : string & ) ;
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR ( const std : : vector < int > & ) ;
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR ( const std : : vector < float > & ) ;
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR ( const std : : vector < std : : string > & ) ;
// NOTE(chenweihang): InferShape can't support std::vector<int64_t> attr type,
// because the input type is std::vector<int64_t>, only can use one rule to
// parse std::vector<int64_t> parameter
// end: base template
template < typename T >
struct InferShapeCallHelper < TypeTag < T > > {
template < int in_idx , int vec_in_idx , int attr_idx >
template < int in_idx , int vec_in_idx >
static Return InferShape (
const std : : vector < std : : vector < int64_t > > & input_shapes ,
const std : : vector < std : : vector < std : : vector < int64_t > > > & vec_input_shapes ,
const std: : vector < boost : : any > & attrs , const Args& . . . args ) {
const Args & . . . args ) {
return impl_fn ( args . . . ) ;
}
} ;