From 6d8fb56ea821431e1e51b8d1b50f1079d7052143 Mon Sep 17 00:00:00 2001
From: zhaoting
Date: Mon, 16 Nov 2020 14:09:34 +0800
Subject: [PATCH] fix mul broadcast and tanhgrad bugs

TanhGrad computed dout * (1 - y)^2 instead of the correct
dout * (1 - y^2), where y = tanh(x). The MKL-DNN Mul and TensorAdd
kernels also keyed broadcast shape padding on an input rank of 0:
two scalar inputs left the output shape empty, and a single-element
input with nonzero rank was never padded. Key the padding on element
count and on rank relative to the output shape instead, promote
0-dimensional shapes to rank 1, and add scalar-with-scalar test
cases for Mul and TensorAdd.
---
 .../cpu/eltwise_grad_cpu_kernel.cc     |  4 ++--
 .../cpu/mkldnn/mul_cpu_kernel.cc       | 21 +++++++++++++++++----
 .../cpu/mkldnn/tensoradd_cpu_kernel.cc | 21 +++++++++++++++++----
 tests/st/ops/cpu/test_mul_op.py        |  9 ++++++++
 tests/st/ops/cpu/test_tensoradd.py     |  9 ++++++++
 5 files changed, 54 insertions(+), 10 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
index e7a91461ea..799260c3a4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
@@ -73,8 +73,8 @@ void EltWiseGradCPUKernel::SqrtGrad(const T *input1, const T *input2, T *out, si
 template <typename T>
 void EltWiseGradCPUKernel::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
   for (size_t i = start; i < end; i++) {
-    T tmp = (1 - input1[i]);
-    out[i] = input2[i] * tmp * tmp;
+    T tmp = input1[i] * input1[i];
+    out[i] = input2[i] * (1 - tmp);
   }
 }
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc
index 6c2af2f87d..243a003ecf 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc
@@ -25,14 +25,27 @@ void MulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
   std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
+  if (dst_shape.size() == 0) {
+    dst_shape.emplace_back(1);
+    src0_shape.emplace_back(1);
+    src1_shape.emplace_back(1);
+  }
+  size_t src0_length = 1;
+  size_t src1_length = 1;
+  for (size_t i = 0; i < src0_shape.size(); ++i) {
+    src0_length = src0_length * src0_shape[i];
+  }
+  for (size_t i = 0; i < src1_shape.size(); ++i) {
+    src1_length = src1_length * src1_shape[i];
+  }
   if (src1_shape.size() != src0_shape.size()) {
-    if (src0_shape.size() == 0) {
+    if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
       need_swap_ = true;
-      for (size_t i = 0; i < src1_shape.size(); ++i) {
+      for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
         src0_shape.emplace_back(1);
       }
-    } else if (src1_shape.size() == 0) {
-      for (size_t i = 0; i < src0_shape.size(); ++i) {
+    } else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
+      for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
         src1_shape.emplace_back(1);
       }
     } else {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.cc
index 85e74e27d6..fc09ebbbd9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.cc
@@ -25,14 +25,27 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
   std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
+  if (dst_shape.size() == 0) {
+    dst_shape.emplace_back(1);
+    src0_shape.emplace_back(1);
+    src1_shape.emplace_back(1);
+  }
+  size_t src0_length = 1;
+  size_t src1_length = 1;
+  for (size_t i = 0; i < src0_shape.size(); ++i) {
+    src0_length = src0_length * src0_shape[i];
+  }
+  for (size_t i = 0; i < src1_shape.size(); ++i) {
+    src1_length = src1_length * src1_shape[i];
+  }
   if (src1_shape.size() != src0_shape.size()) {
-    if (src0_shape.size() == 0) {
+    if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
       need_swap_ = true;
-      for (size_t i = 0; i < src1_shape.size(); ++i) {
+      for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
         src0_shape.emplace_back(1);
       }
-    } else if (src1_shape.size() == 0) {
-      for (size_t i = 0; i < src0_shape.size(); ++i) {
+    } else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
+      for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
         src1_shape.emplace_back(1);
       }
     } else {
diff --git a/tests/st/ops/cpu/test_mul_op.py b/tests/st/ops/cpu/test_mul_op.py
index 88d1e71eef..c37aae6888 100644
--- a/tests/st/ops/cpu/test_mul_op.py
+++ b/tests/st/ops/cpu/test_mul_op.py
@@ -45,6 +45,8 @@ def test_mul():
     y1 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32))
     x2 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32))
     y2 = Tensor(2, mstype.float32)
+    x3 = Tensor(2, mstype.float32)
+    y3 = Tensor(2, mstype.float32)
     mul = Net()
     out = mul(x0, y0).asnumpy()
     exp = x0.asnumpy() * y0.asnumpy()
@@ -66,3 +68,10 @@
     err = np.ones(shape=exp.shape) * 1.0e-5
     assert np.all(diff < err)
     assert out.shape == exp.shape
+
+    out = mul(x3, y3).asnumpy()
+    exp = x3.asnumpy() * y3.asnumpy()
+    diff = np.abs(out - exp)
+    err = np.ones(shape=exp.shape) * 1.0e-5
+    assert np.all(diff < err)
+    assert out.shape == exp.shape
diff --git a/tests/st/ops/cpu/test_tensoradd.py b/tests/st/ops/cpu/test_tensoradd.py
index 6284c0dbbc..f0373370ed 100644
--- a/tests/st/ops/cpu/test_tensoradd.py
+++ b/tests/st/ops/cpu/test_tensoradd.py
@@ -43,6 +43,8 @@ def test_tensor_add():
     y1 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32))
     x2 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32))
     y2 = Tensor(2, mstype.float32)
+    x3 = Tensor(2, mstype.float32)
+    y3 = Tensor(2, mstype.float32)
     add = TensorAdd()
     out = add(x0, y0).asnumpy()
     exp = x0.asnumpy() + y0.asnumpy()
@@ -64,3 +66,10 @@
     err = np.ones(shape=exp.shape) * 1.0e-5
     assert np.all(diff < err)
     assert out.shape == exp.shape
+
+    out = add(x3, y3).asnumpy()
+    exp = x3.asnumpy() + y3.asnumpy()
+    diff = np.abs(out - exp)
+    err = np.ones(shape=exp.shape) * 1.0e-5
+    assert np.all(diff < err)
+    assert out.shape == exp.shape
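
Reviewer note: a quick standalone check of the corrected TanhGrad formula.
This NumPy sketch is illustrative only (it is not part of the patch or the
MindSpore test suite); it compares the fixed expression dout * (1 - y^2),
with y = tanh(x), against a central finite difference, and shows that the
pre-fix expression dout * (1 - y)^2 does not match the numeric gradient.

    import numpy as np

    # y = tanh(x) is the forward output the kernel receives as input1;
    # dout is the incoming gradient (input2).
    x = np.linspace(-2.0, 2.0, 9)
    y = np.tanh(x)
    dout = np.full_like(x, 0.5)

    # Fixed formula: dx = dout * (1 - y^2), since d/dx tanh(x) = 1 - tanh(x)^2.
    fixed = dout * (1 - y * y)

    # Independent reference: central finite difference of tanh at x.
    eps = 1e-5
    numeric = dout * (np.tanh(x + eps) - np.tanh(x - eps)) / (2 * eps)
    assert np.allclose(fixed, numeric, atol=1e-8)

    # Pre-fix formula: dx = dout * (1 - y)^2, which agrees only where y == 0.
    buggy = dout * (1 - y) ** 2
    assert not np.allclose(buggy, numeric)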
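
Reviewer note: the new shape-alignment logic in the Mul/TensorAdd InitKernel,
transcribed into Python so it can be exercised in isolation. align_shapes is
a hypothetical helper name, and the ValueError stands in for the kernel's
unsupported-broadcast error path, which lies outside the hunk shown above.

    from math import prod

    def align_shapes(src0, src1, dst):
        """Mirror the patched InitKernel logic: make input ranks equal by
        padding a single-element input with trailing 1s."""
        src0, src1, dst = list(src0), list(src1), list(dst)
        # Scalar output: promote everything to rank 1 so the backend never
        # sees empty dims (the scalar-with-scalar case the tests add).
        if len(dst) == 0:
            dst.append(1)
            src0.append(1)
            src1.append(1)
        need_swap = False
        if len(src1) != len(src0):
            # Pad based on element count (prod == 1), not rank == 0 as
            # before, so a shape like (1,) also aligns with a higher rank.
            if prod(src0) == 1 and len(src0) != len(dst):
                need_swap = True
                src0 += [1] * (len(src1) - len(src0))
            elif prod(src1) == 1 and len(src1) != len(dst):
                src1 += [1] * (len(src0) - len(src1))
            else:
                raise ValueError("unsupported broadcast")  # stand-in error path
        return src0, src1, dst, need_swap

    # Scalar (rank 0) against a rank-4 tensor: padded to rank 4 and swapped.
    assert align_shapes([], [2, 3, 4, 4], [2, 3, 4, 4]) == \
        ([1, 1, 1, 1], [2, 3, 4, 4], [2, 3, 4, 4], True)
    # Scalar against scalar: everything is promoted to rank 1, no swap.
    assert align_shapes([], [], []) == ([1], [1], [1], False)

Note that the padding appends trailing 1s, which is safe only because the
padded operand holds exactly one element; that is precisely the condition
the patch now checks explicitly.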