|
|
|
@ -108,7 +108,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
|
|
|
|
|
"Rank of first input must >= rank of second input.")
|
|
|
|
|
|
|
|
|
|
if (x_dims == y_dims || product(y_dims) == 1) {
|
|
|
|
|
if (x_dims == y_dims) {
|
|
|
|
|
functor f;
|
|
|
|
|
f.template Run<Place, T>(x, y, z, ctx);
|
|
|
|
|
return;
|
|
|
|
@ -174,12 +174,6 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (product(y_dims) == 1) {
|
|
|
|
|
functor1 f;
|
|
|
|
|
f(place, x, y, out, dx, dy, dout);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
|
|
|
|
|
|
|
|
|
|