|
|
|
@ -141,8 +141,6 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
auto* x = ctx.Input<Tensor>("X");
|
|
|
|
|
auto* y = ctx.Input<Tensor>("Y");
|
|
|
|
|
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
@ -150,13 +148,13 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (dx != nullptr) dx->ShareDataWith(*dout);
|
|
|
|
|
if (dy == nullptr) return;
|
|
|
|
|
|
|
|
|
|
if (x->dims() == y->dims()) {
|
|
|
|
|
const framework::DDim& x_dim = dout->dims();
|
|
|
|
|
framework::DDim y_dim = dy->dims();
|
|
|
|
|
if (x_dim == y_dim) {
|
|
|
|
|
dy->ShareDataWith(*dout);
|
|
|
|
|
} else {
|
|
|
|
|
dy->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
// Perform reduction to dout to calculate dy
|
|
|
|
|
const framework::DDim& x_dim = x->dims();
|
|
|
|
|
framework::DDim y_dim = y->dims();
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
|
|
|
|
|
y_dim = trim_trailing_singular_dims(y_dim);
|
|
|
|
|