|
|
|
@ -340,6 +340,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (y_dims.size() == 1 && y_dims[0] == 1) {
|
|
|
|
|
// y is a scalar
|
|
|
|
|
auto extended_dims = framework::vectorize(x_dims);
|
|
|
|
|
extended_dims.push_back(1);
|
|
|
|
|
x_dims = framework::make_ddim(extended_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
|
|
|
|
|
|
|
|
|
@ -378,6 +385,13 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (y_dims.size() == 1 && y_dims[0] == 1) {
|
|
|
|
|
// y is a scalar
|
|
|
|
|
auto extended_dims = framework::vectorize(x_dims);
|
|
|
|
|
extended_dims.push_back(1);
|
|
|
|
|
x_dims = framework::make_ddim(extended_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
|
|
|
|
|
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
|
|
|
|
|