|
|
|
@ -144,41 +144,16 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
// skip out, x, y
|
|
|
|
|
auto* out = dout;
|
|
|
|
|
auto *x = dout, *y = dout;
|
|
|
|
|
|
|
|
|
|
if (dx != nullptr) {
|
|
|
|
|
// In fact, we can just share memory, but it may cause a bug of memory
|
|
|
|
|
// optimizer
|
|
|
|
|
// dx->ShareDataWith(*dout);
|
|
|
|
|
framework::TensorCopy(*dout, ctx.GetPlace(),
|
|
|
|
|
ctx.template device_context<DeviceContext>(), dx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dy == nullptr) return;
|
|
|
|
|
|
|
|
|
|
const framework::DDim& x_dim = dout->dims();
|
|
|
|
|
framework::DDim y_dim = dy->dims();
|
|
|
|
|
if (x_dim == y_dim) {
|
|
|
|
|
// dy->ShareDataWith(*dout);
|
|
|
|
|
framework::TensorCopy(*dout, ctx.GetPlace(),
|
|
|
|
|
ctx.template device_context<DeviceContext>(), dy);
|
|
|
|
|
if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
|
|
|
|
|
dy != nullptr && (dx->dims() == dy->dims())) {
|
|
|
|
|
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
|
|
|
|
|
} else {
|
|
|
|
|
dy->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
// Perform reduction to dout to calculate dy
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
|
|
|
|
|
y_dim = trim_trailing_singular_dims(y_dim);
|
|
|
|
|
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
|
|
|
|
|
|
|
|
|
|
auto& device =
|
|
|
|
|
*(ctx.template device_context<DeviceContext>().eigen_device());
|
|
|
|
|
int pre, n, post;
|
|
|
|
|
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
|
|
|
|
|
auto eigen_dout = framework::EigenTensor<T, 3>::From(
|
|
|
|
|
*dout, framework::make_ddim({pre, n, post}));
|
|
|
|
|
auto eigen_dy =
|
|
|
|
|
framework::EigenTensor<T, 1>::From(*dy, framework::make_ddim({n}));
|
|
|
|
|
eigen_dy.device(device) = eigen_dout.sum(
|
|
|
|
|
framework::EigenDim<2>::From(framework::make_ddim({0, 2})));
|
|
|
|
|
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
|
|
|
|
|
dy);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|