@@ -104,7 +104,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
       int axis = ctx.Attr<int>("axis");
       int rankdiff = ctx.Input<Tensor>("X")->dims().size() -
                      ctx.Input<Tensor>("Y")->dims().size();
-      return (axis == -1) || (axis == rankdiff);
+      return (rankdiff == 0) || (axis == -1) || (axis == rankdiff);
     };

     if (platform::CanMKLDNNBeUsed(ctx) &&
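
Note on the hunk above: the broadcast-eligibility lambda in ElementwiseOp used to accept only axis == -1 or axis == rankdiff; it now also accepts rankdiff == 0, so inputs of equal rank pass this check regardless of the axis attribute (selecting the MKLDNN kernel still additionally requires platform::CanMKLDNNBeUsed(ctx)).
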
@@ -243,9 +243,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
 #ifdef PADDLE_WITH_MKLDNN
     // If broadcasting is needed, use native implementation
     auto CanMKLDNNElementwiseAddGradBeUsed = [&]() {
-      auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-      auto dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-      return (dx != nullptr && dy != nullptr && dx->dims() == dy->dims());
+      return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
     };

     if (platform::CanMKLDNNBeUsed(ctx) &&
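
Note on the hunk above: CanMKLDNNElementwiseAddGradBeUsed no longer inspects the gradient outputs dX/dY for null and equal dims; it simply compares the dims of the forward inputs X and Y, i.e. it asks whether any broadcasting took place in the forward pass (presumably so the check also holds when only one of the two gradients is requested). Below is a minimal standalone sketch, not Paddle code: Shape, ForwardCanUseMKLDNN and GradCanUseMKLDNN are illustrative names, and plain shape vectors stand in for ctx.Input<Tensor>(...)->dims(). It models the two predicates touched by this diff so the rule can be exercised in isolation.

// Standalone sketch (not the Paddle API): models the two eligibility checks
// from this diff on plain shape vectors.
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Forward op check: no explicit rank alignment is needed -- equal ranks,
// default axis (-1), or axis pointing exactly at the rank gap.
bool ForwardCanUseMKLDNN(const Shape& x, const Shape& y, int axis) {
  int rankdiff = static_cast<int>(x.size()) - static_cast<int>(y.size());
  return (rankdiff == 0) || (axis == -1) || (axis == rankdiff);
}

// Grad op check after this change: X and Y have identical dims, i.e. no
// broadcasting happened in the forward pass.
bool GradCanUseMKLDNN(const Shape& x, const Shape& y) { return x == y; }

int main() {
  assert(ForwardCanUseMKLDNN({2, 3, 4}, {2, 3, 4}, /*axis=*/1));  // rankdiff == 0
  assert(ForwardCanUseMKLDNN({2, 3, 4}, {4}, /*axis=*/-1));       // default axis
  assert(ForwardCanUseMKLDNN({2, 3, 4}, {4}, /*axis=*/2));        // axis == rankdiff
  assert(!ForwardCanUseMKLDNN({2, 3, 4}, {3, 4}, /*axis=*/0));    // needs alignment

  assert(GradCanUseMKLDNN({2, 3}, {2, 3}));  // same dims: MKLDNN-eligible
  assert(!GradCanUseMKLDNN({2, 3}, {3}));    // broadcast: native path
  return 0;
}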