|
|
@ -31,6 +31,12 @@ class InferShapeBase {
|
|
|
|
virtual void operator()(InferShapeContext*) const = 0;
|
|
|
|
virtual void operator()(InferShapeContext*) const = 0;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EstimateFlopsBase {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
virtual ~EstimateFlopsBase() = default;
|
|
|
|
|
|
|
|
virtual size_t operator()(InferShapeContext*) const = 0;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
struct OpInfo {
|
|
|
|
struct OpInfo {
|
|
|
|
OpCreator creator_;
|
|
|
|
OpCreator creator_;
|
|
|
|
GradOpMakerFN grad_op_maker_;
|
|
|
|
GradOpMakerFN grad_op_maker_;
|
|
|
@ -38,6 +44,7 @@ struct OpInfo {
|
|
|
|
OpAttrChecker* checker_{nullptr};
|
|
|
|
OpAttrChecker* checker_{nullptr};
|
|
|
|
InferVarTypeFN infer_var_type_;
|
|
|
|
InferVarTypeFN infer_var_type_;
|
|
|
|
InferShapeFN infer_shape_;
|
|
|
|
InferShapeFN infer_shape_;
|
|
|
|
|
|
|
|
EstimateFlopsFN estimate_flops_;
|
|
|
|
|
|
|
|
|
|
|
|
bool HasOpProtoAndChecker() const {
|
|
|
|
bool HasOpProtoAndChecker() const {
|
|
|
|
return proto_ != nullptr && checker_ != nullptr;
|
|
|
|
return proto_ != nullptr && checker_ != nullptr;
|
|
|
|