|
|
|
@ -32,9 +32,7 @@ enum OpInfoFillType {
|
|
|
|
|
kOpProtoAndCheckerMaker = 1,
|
|
|
|
|
kGradOpDescMaker = 2,
|
|
|
|
|
kVarTypeInference = 3,
|
|
|
|
|
kShapeInference = 4,
|
|
|
|
|
kEstimateFlops = 5,
|
|
|
|
|
kUnknown = -1
|
|
|
|
|
kShapeInference = 4
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -50,10 +48,8 @@ struct OpInfoFillTypeID {
|
|
|
|
|
? kVarTypeInference
|
|
|
|
|
: (std::is_base_of<InferShapeBase, T>::value
|
|
|
|
|
? kShapeInference
|
|
|
|
|
: (std::is_base_of<EstimateFlopsBase,
|
|
|
|
|
T>::value
|
|
|
|
|
? kEstimateFlops
|
|
|
|
|
: kUnknown)))));
|
|
|
|
|
: static_cast<OpInfoFillType>(
|
|
|
|
|
-1)))));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -143,16 +139,6 @@ struct OpInfoFiller<T, kShapeInference> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct OpInfoFiller<T, kEstimateFlops> {
|
|
|
|
|
void operator()(const char* op_tpe, OpInfo* info) const {
|
|
|
|
|
info->estimate_flops_ = [](InferShapeContext* ctx) {
|
|
|
|
|
T estimate_flops;
|
|
|
|
|
return estimate_flops(ctx);
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace details
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|