Add ElementwiseOpInferVarType

shanyi15-patch-2
chengduoZH 7 years ago
parent 0d49b92140
commit 53d19f5b1e

@ -29,8 +29,11 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp,
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,

@ -41,6 +41,16 @@ class ElementwiseOp : public framework::OperatorWithKernel {
}
};
class ElementwiseOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto x_var = op_desc.Input("X")[0];
auto out_var = op_desc.Output("Out")[0];
block->Var(out_var)->SetType(block->Var(x_var)->GetType());
}
};
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)

Loading…
Cancel
Save