@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License . */
# include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
# include <memory>
# include <string>
# include "paddle/fluid/operators/elementwise/elementwise_op.h"
@ -43,6 +44,30 @@ class ElementwiseMulOpMaker : public ElementwiseOpMaker {
virtual std : : string GetEquation ( ) const { return " Out = X \\ \\ odot Y " ; }
} ;
class ElementwiseMulDoubleGradDescMaker
: public framework : : SingleGradOpDescMaker {
public :
using framework : : SingleGradOpDescMaker : : SingleGradOpDescMaker ;
protected :
std : : unique_ptr < framework : : OpDesc > Apply ( ) const override {
std : : unique_ptr < framework : : OpDesc > op ( new framework : : OpDesc ( ) ) ;
op - > SetType ( " elementwise_mul_grad_grad " ) ;
op - > SetInput ( " X " , Input ( " X " ) ) ;
op - > SetInput ( " Y " , Input ( " Y " ) ) ;
op - > SetInput ( " DOut " , Input ( framework : : GradVarName ( " Out " ) ) ) ;
op - > SetInput ( " DDX " , OutputGrad ( framework : : GradVarName ( " X " ) ) ) ;
op - > SetInput ( " DDY " , OutputGrad ( framework : : GradVarName ( " Y " ) ) ) ;
op - > SetAttrMap ( Attrs ( ) ) ;
op - > SetOutput ( " DDOut " , InputGrad ( framework : : GradVarName ( " Out " ) ) ) ;
op - > SetOutput ( framework : : GradVarName ( " X " ) , InputGrad ( " X " ) ) ;
op - > SetOutput ( framework : : GradVarName ( " Y " ) , InputGrad ( " Y " ) ) ;
return op ;
}
} ;
} // namespace operators
} // namespace paddle
@ -50,7 +75,9 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR ( elementwise_mul , ops : : ElementwiseOp ,
ops : : ElementwiseMulOpMaker , ops : : ElementwiseOpInferVarType ,
ops : : ElementwiseMulOpGradDescMaker ) ;
REGISTER_OPERATOR ( elementwise_mul_grad , ops : : ElementwiseOpGrad ) ;
REGISTER_OPERATOR ( elementwise_mul_grad , ops : : ElementwiseOpGrad ,
ops : : ElementwiseMulDoubleGradDescMaker ) ;
REGISTER_OPERATOR ( elementwise_mul_grad_grad , ops : : ElementwiseOpDoubleGrad ) ;
REGISTER_OP_CPU_KERNEL (
elementwise_mul ,
@ -64,3 +91,13 @@ REGISTER_OP_CPU_KERNEL(
ops : : ElementwiseMulGradKernel < paddle : : platform : : CPUDeviceContext , double > ,
ops : : ElementwiseMulGradKernel < paddle : : platform : : CPUDeviceContext , int > ,
ops : : ElementwiseMulGradKernel < paddle : : platform : : CPUDeviceContext , int64_t > ) ;
REGISTER_OP_CPU_KERNEL (
elementwise_mul_grad_grad ,
ops : : ElementwiseMulDoubleGradKernel < paddle : : platform : : CPUDeviceContext ,
float > ,
ops : : ElementwiseMulDoubleGradKernel < paddle : : platform : : CPUDeviceContext ,
double > ,
ops : : ElementwiseMulDoubleGradKernel < paddle : : platform : : CPUDeviceContext ,
int > ,
ops : : ElementwiseMulDoubleGradKernel < paddle : : platform : : CPUDeviceContext ,
int64_t > ) ;