@@ -172,34 +172,50 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
     // (2) dy = dout * ddx
     // (3) ddout = ddx * y
     // (4) ddout = ddout + dx
-    // (5) dx = dout *ddy
+    // (5) dx = dout * ddy
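For out = x * y, the double-grad inputs ddx and ddy (the gradients of dx and dy) yield dy = dout * ddx, dx = dout * ddy, and ddout = ddx * y + x * ddy. The numbered steps above simply reorder that computation so that ddout can be written into ddx's buffer in place and dx can serve as scratch space until step (5) overwrites it with its final value.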
     if (ddout) {
-      // use dx to save memory, other than alloc tmp tensor
-      Tensor* ddout_tmp = dx;
-
-      default_elementwise_mul<DeviceContext, T>(ctx, x, &ddy_safe, ddout_tmp);
       int axis = ctx.Attr<int>("axis");
-      // NOTE: in the following ElemwiseGradCompute, for the
-      // first output tensor is nullptr, the branch to calculate first
-      // output tensor will not be activated, DivGradDx function will not
-      // be called and can be ignored, the first branch has little effect
-      // on running speed.
-      ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
-          ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy,
-          MulGradDX<T>(), MulGradDY<T>());
-      default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, y, ddout);
-
       auto& place =
           *ctx.template device_context<DeviceContext>().eigen_device();
-      auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
-      auto ddout_tmp_t = framework::EigenVector<T>::Flatten(*ddout_tmp);
-      ddout_t.device(place) = ddout_t + ddout_tmp_t;
-
-      default_elementwise_mul<DeviceContext, T>(ctx, dout, &ddy_safe, dx);
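The lines removed above compute ddout entirely through dx-sized scratch storage, which only works when ddout and ddx hold the same number of elements. When broadcasting makes ddout larger than ddx that layout no longer fits, which appears to be what the branching added below addresses.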
+      // when size(ddout) > size(ddx), ddout can't reuse ddx's memory in place
+      if (ddout->numel() > ddx->numel()) {
+        ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
+            ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX<T>(),
+            MulGradDY<T>());
+
+        Tensor ddout_tmp;
+        ddout_tmp.mutable_data<T>(ddout->dims(), ctx.GetPlace());
+
+        default_elementwise_mul<DeviceContext, T>(ctx, y, &ddx_safe, ddout);
+        default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, x,
+                                                  &ddout_tmp);
+
+        auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
+        auto ddout_tmp_t = framework::EigenVector<T>::Flatten(ddout_tmp);
+        ddout_t.device(place) = ddout_t + ddout_tmp_t;
+      } else {
+        // use dx to save memory, rather than allocating a tmp tensor
+        Tensor* ddout_tmp = dx;
+
+        default_elementwise_mul<DeviceContext, T>(ctx, x, &ddy_safe, ddout_tmp);
+        // NOTE: in the following ElemwiseGradCompute, the first output
+        // tensor is nullptr, so the branch that would compute the first
+        // output is never taken and the MulGradDX functor is never
+        // called; skipping that branch has little effect on running
+        // speed.
+        ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
+            ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy,
+            MulGradDX<T>(), MulGradDY<T>());
+        default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, y, ddout);
+
+        auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
+        auto ddout_tmp_t = framework::EigenVector<T>::Flatten(*ddout_tmp);
+        ddout_t.device(place) = ddout_t + ddout_tmp_t;
+
+        default_elementwise_mul<DeviceContext, T>(ctx, dout, &ddy_safe, dx);
+      }
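In the first branch ElemwiseGradCompute writes dx and dy directly and ddout is accumulated into a freshly allocated ddout_tmp, so nothing has to fit in ddx's buffer. The second branch keeps the original memory-saving scheme: dx temporarily holds x * ddy, ddout (which may alias ddx via the inplace pair declared below) is built in place, and dx is recomputed as dout * ddy at the end.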
     }
   }
 };

 DECLARE_INPLACE_OP_INFERER(ElementwiseMulDoubleGradOpInplace, {"DDX", "DDOut"});
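The inplace inferer declares that DDOut may reuse DDX's memory; the numel() check added above is what keeps that reuse safe when the two tensors differ in size.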

 } // namespace operators
 } // namespace paddle
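For reference, here is a minimal standalone sketch of the same double-grad formulas for the equal-shape (no broadcasting) case, written with plain std::vector instead of Paddle tensors; the values are made up purely for illustration.

// Illustrative only: double-grad of out = x * y on equal-shape inputs.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> x = {1.f, 2.f, 3.f};
  const std::vector<float> y = {4.f, 5.f, 6.f};
  const std::vector<float> dout = {1.f, 1.f, 1.f};   // incoming grad of out
  const std::vector<float> ddx = {0.1f, 0.2f, 0.3f}; // incoming grad of dx
  const std::vector<float> ddy = {0.4f, 0.5f, 0.6f}; // incoming grad of dy

  std::vector<float> dx(x.size()), dy(x.size()), ddout(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    dy[i] = dout[i] * ddx[i];                  // (2) dy = dout * ddx
    ddout[i] = ddx[i] * y[i] + x[i] * ddy[i];  // (3)+(4) ddout = ddx*y + x*ddy
    dx[i] = dout[i] * ddy[i];                  // (5) dx = dout * ddy
  }
  for (size_t i = 0; i < x.size(); ++i) {
    std::printf("dx=%.2f dy=%.2f ddout=%.2f\n", dx[i], dy[i], ddout[i]);
  }
  return 0;
}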