[oneDNN] Fix to elementwise_add grad (#24639)

v1.8
Jacek Czaja 5 years ago committed by GitHub
parent 824572c144
commit ca68b13f85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -104,7 +104,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
int axis = ctx.Attr<int>("axis");
int rankdiff = ctx.Input<Tensor>("X")->dims().size() -
ctx.Input<Tensor>("Y")->dims().size();
return (axis == -1) || (axis == rankdiff);
return (rankdiff == 0) || (axis == -1) || (axis == rankdiff);
};
if (platform::CanMKLDNNBeUsed(ctx) &&
@ -243,9 +243,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
// If broadcasting is needed, use native implementation
auto CanMKLDNNElementwiseAddGradBeUsed = [&]() {
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
return (dx != nullptr && dy != nullptr && dx->dims() == dy->dims());
return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
};
if (platform::CanMKLDNNBeUsed(ctx) &&

@ -85,6 +85,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
in->set_format(out->format());
};
// TODO(jczaja): Double check if vcopy works for blocked data
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),

Loading…
Cancel
Save