From fc002405758ab856ff871218372cc9ab3b6a94f6 Mon Sep 17 00:00:00 2001
From: Wojciech Uss
Date: Thu, 28 Jan 2021 12:57:09 +0100
Subject: [PATCH] A fix for oneDNN matmul kernel. Fixes issue #30309 (#30723)

---
 .../operators/mkldnn/matmul_mkldnn_op.cc      | 28 +++++++++----------
 .../unittests/mkldnn/test_matmul_mkldnn_op.py | 22 ++++++++++++---
 2 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
index fb856d9740..3ef9d88e4e 100644
--- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
@@ -188,34 +188,34 @@ class MatMulFactory {
     memory::dims strides_y;
     std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");
 
-    const auto x_bs = mat_dim_x.batch_size_;
-    const auto y_bs = mat_dim_y.batch_size_;
+    auto x_bs = mat_dim_x.batch_size_;
+    auto y_bs = mat_dim_y.batch_size_;
     PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false,
                       platform::errors::InvalidArgument(
                           "If batch sizes of X and Y are positive,"
                           "they have to be equal."));
 
-    // Store 1 if both batches are zero, otherwise save the nonzero batch
-    const memory::dim BS = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
+    memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
     const memory::dim M = mat_dim_x.height_;
     const memory::dim N = mat_dim_y.width_;
     const memory::dim K = mat_dim_x.width_;
 
     batch_size_ = 1;
-    auto b = BS;
-    if (BS > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
+    if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
       auto& x_dims = ctx.Input<Tensor>("X")->dims();
       auto& y_dims = ctx.Input<Tensor>("Y")->dims();
       batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
-      b = BS / batch_size_;
+      x_bs /= batch_size_;
+      y_bs /= batch_size_;
+      out_bs /= batch_size_;
     }
-    memory::dims x_dims = {b, M, K};
-    memory::dims y_dims = {b, K, N};
-    memory::dims out_dims = {b, M, N};
+    memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
+    memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
+    memory::dims out_dims = {out_bs, M, N};
 
-    x_offset_ = b * M * K * sizeof(XT);
-    y_offset_ = b * K * N * sizeof(YT);
-    out_offset_ = b * M * N * sizeof(OT);
+    x_offset_ = x_bs * M * K * sizeof(XT);
+    y_offset_ = y_bs * K * N * sizeof(YT);
+    out_offset_ = out_bs * M * N * sizeof(OT);
 
     // Translate transA and transB
     if (strides_x.empty())
@@ -226,7 +226,7 @@ class MatMulFactory {
                                     : memory::dims{N * K, 1, K};
     memory::dims out_strides = memory::dims{M * N, N, 1};
 
-    CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides);
+    CorrectStridesWhenFloatOutputFused(ctx, N, out_bs, &out_strides);
 
     return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
   }
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
index 9a5443eed1..2f557f0bf1 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
@@ -48,6 +48,20 @@ class TestDnnlMatMulOp(OpTest):
         self.check_output()
 
 
+class TestDnnlMatMulOpMixedDims1(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.random((17, 2, 3)).astype("float32")
+        self.y = np.random.random((3, 4)).astype("float32")
+        self.out = np.matmul(self.x, self.y)
+
+
+class TestDnnlMatMulOpMixedDims2(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.random((2, 3)).astype("float32")
+        self.y = np.random.random((17, 3, 4)).astype("float32")
+        self.out = np.matmul(self.x, self.y)
+
+
 class TestDnnlMatMulOpAlpha(TestDnnlMatMulOp):
     def generate_data(self):
         self.x = np.random.random((17, 2, 3)).astype("float32")
@@ -396,10 +410,10 @@ class TestMatMulOpTransposeReshapeBasicFloat(
         TestMatMulOpTransposeReshapeEmptyFloat):
     def generate_data(self):
         self.bs = 8
-        self.x = np.random.random(
-            [self.bs, 12, 128, 128]).astype(self.data_type_)
-        self.y = np.random.random(
-            [self.bs, 12, 128, 64]).astype(self.data_type_)
+        self.x = np.random.random([self.bs, 12, 128,
+                                   128]).astype(self.data_type_)
+        self.y = np.random.random([self.bs, 12, 128,
+                                   64]).astype(self.data_type_)
 
     def init_params_and_out(self):
         self.transpose_out = [0, 2, 1, 3]
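
For context, the shape/offset logic after this change can be summarized as
follows, in a Python pseudocode sketch that is not part of the patch (the
function name and the itemsize parameter are illustrative). Each operand now
keeps its own batch size, so a 2D operand, whose batch size is reported as 0,
gets a batch dimension of 1 in its oneDNN descriptor and a zero offset between
executions, instead of inheriting the other operand's batch as before:

    def matmul_dims_and_offsets(x_bs, y_bs, M, N, K, itemsize=4):
        # A 2D operand reports batch size 0; only two positive but unequal
        # batch sizes are rejected (mirrors the PADDLE_ENFORCE_EQ above).
        assert not (x_bs > 0 and y_bs > 0 and x_bs != y_bs)
        # Output batch: the nonzero batch, or 1 when both operands are 2D.
        out_bs = max(x_bs, y_bs) if (x_bs or y_bs) else 1
        # Each operand keeps its own batch; a 2D operand becomes batch 1,
        # which oneDNN broadcasts, rather than claiming memory it lacks.
        x_dims = (x_bs if x_bs > 0 else 1, M, K)
        y_dims = (y_bs if y_bs > 0 else 1, K, N)
        out_dims = (out_bs, M, N)
        # Per-operand byte offsets: the broadcast (2D) side gets offset 0,
        # so its data pointer is not advanced between executions.
        offsets = (x_bs * M * K * itemsize,
                   y_bs * K * N * itemsize,
                   out_bs * M * N * itemsize)
        return x_dims, y_dims, out_dims, offsets

This covers the non-fused path (batch_size_ == 1); in the fused branch the
patch additionally divides all three batch sizes by batch_size_.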
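The two new MixedDims tests pin down exactly this broadcast case: one operand
is 2D, the other carries the batch. A minimal standalone NumPy check of the
reference shapes (illustrative only; the real tests run through OpTest):

    import numpy as np

    # 3D x 2D: the 2D right-hand side is reused across all 17 batches.
    x = np.random.random((17, 2, 3)).astype("float32")
    y = np.random.random((3, 4)).astype("float32")
    assert np.matmul(x, y).shape == (17, 2, 4)

    # 2D x 3D: the symmetric case; the batch comes from the right operand.
    x = np.random.random((2, 3)).astype("float32")
    y = np.random.random((17, 3, 4)).astype("float32")
    assert np.matmul(x, y).shape == (17, 2, 4)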