|
|
|
@ -98,9 +98,22 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
|
|
|
|
|
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
|
|
|
|
|
IdentityGrad<T>());
|
|
|
|
|
|
|
|
|
|
if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
|
|
|
|
|
if (dx)
|
|
|
|
|
dx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
if (dy)
|
|
|
|
|
dy->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
blas.VCOPY(dout->numel(), dout->data<T>(), dx->data<T>());
|
|
|
|
|
blas.VCOPY(dout->numel(), dout->data<T>(), dy->data<T>());
|
|
|
|
|
} else {
|
|
|
|
|
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
|
|
|
|
|
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
|
|
|
|
|
IdentityGrad<T>());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|