|
|
|
@ -264,6 +264,23 @@ class ElementwiseOpInplace : public framework::InplaceInToOut {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ElementwiseGradOpInplace : public framework::InplaceInToOut {
|
|
|
|
|
public:
|
|
|
|
|
using framework::InplaceInToOut::InplaceInToOut;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unordered_map<std::string, std::string> Apply(
|
|
|
|
|
const framework::OpDesc &op_desc,
|
|
|
|
|
framework::BlockDesc *block) const override {
|
|
|
|
|
std::unordered_map<std::string, std::string> ret;
|
|
|
|
|
if (block->HasVar(framework::GradVarName("X")) &&
|
|
|
|
|
block->HasVar(framework::GradVarName("Out"))) {
|
|
|
|
|
ret[framework::GradVarName("Out")] = framework::GradVarName("X");
|
|
|
|
|
}
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -316,4 +333,5 @@ class ElementwiseOpInplace : public framework::InplaceInToOut {
|
|
|
|
|
op_type##GradMaker, \
|
|
|
|
|
::paddle::operators::ElementwiseOpInplace); \
|
|
|
|
|
REGISTER_OPERATOR(op_type##_grad, \
|
|
|
|
|
::paddle::operators::ElementwiseOpExplicitGrad)
|
|
|
|
|
::paddle::operators::ElementwiseOpExplicitGrad, \
|
|
|
|
|
::paddle::operators::ElementwiseGradOpInplace)
|
|
|
|
|