[oneDNN] Fix to elementwise_add grad (#24639)

v1.8
Jacek Czaja 5 years ago committed by GitHub
parent 824572c144
commit ca68b13f85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -104,7 +104,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
int rankdiff = ctx.Input<Tensor>("X")->dims().size() - int rankdiff = ctx.Input<Tensor>("X")->dims().size() -
ctx.Input<Tensor>("Y")->dims().size(); ctx.Input<Tensor>("Y")->dims().size();
return (axis == -1) || (axis == rankdiff); return (rankdiff == 0) || (axis == -1) || (axis == rankdiff);
}; };
if (platform::CanMKLDNNBeUsed(ctx) && if (platform::CanMKLDNNBeUsed(ctx) &&
@ -243,9 +243,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// If broadcasting is needed, use native implementation // If broadcasting is needed, use native implementation
auto CanMKLDNNElementwiseAddGradBeUsed = [&]() { auto CanMKLDNNElementwiseAddGradBeUsed = [&]() {
auto dx = ctx.Output<Tensor>(framework::GradVarName("X")); return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
auto dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
return (dx != nullptr && dy != nullptr && dx->dims() == dy->dims());
}; };
if (platform::CanMKLDNNBeUsed(ctx) && if (platform::CanMKLDNNBeUsed(ctx) &&

@ -85,6 +85,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
in->set_format(out->format()); in->set_format(out->format());
}; };
// TODO(jczaja): Double check if vcopy works for blocked data
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx); auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) { if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(), blas.VCOPY(dout->numel(), dout->data<T>(),

Loading…
Cancel
Save