|
|
|
@ -42,6 +42,18 @@ class ElementwiseOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ElementwiseOpInferVarType : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc& op_desc,
|
|
|
|
|
framework::BlockDesc* block) const override {
|
|
|
|
|
auto x_name = op_desc.Input("X")[0];
|
|
|
|
|
auto out_name = op_desc.Output("Out")[0];
|
|
|
|
|
auto& x = block->FindRecursiveOrCreateVar(x_name);
|
|
|
|
|
auto& out = block->FindRecursiveOrCreateVar(out_name);
|
|
|
|
|
out.SetType(x.GetType());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() final {
|
|
|
|
@ -138,5 +150,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
}; \
|
|
|
|
|
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
|
|
|
|
|
__ElemwiseOp##op_type##Maker__, \
|
|
|
|
|
::paddle::operators::ElementwiseOpInferVarType, \
|
|
|
|
|
::paddle::framework::DefaultGradOpDescMaker<true>); \
|
|
|
|
|
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
|
|
|
|
|