From 102b74682bacf1bd4a8d3fb3946e4480564aae95 Mon Sep 17 00:00:00 2001
From: zhaoting
Date: Sat, 28 Nov 2020 17:59:55 +0800
Subject: [PATCH] fix dnnl binary broadcast

Replace the duplicated broadcast handling in the Mul and TensorAdd kernels
with a shared MKLCPUKernel::BinaryBroadCast helper that aligns both input
shapes to the output shape, fix the Less kernel to compute its output
length with sizeof(bool) instead of sizeof(T), and add 1-D x 2-D broadcast
cases to the Mul and TensorAdd STs.

---
 .../cpu/arithmetic_cpu_kernel.cc              |  2 +-
 .../cpu/mkldnn/mkl_cpu_kernel.cc              | 46 +++++++++++++++++++
 .../cpu/mkldnn/mkl_cpu_kernel.h               |  2 +
 .../cpu/mkldnn/mul_cpu_kernel.cc              | 44 +-----------------
 .../cpu/mkldnn/tensoradd_cpu_kernel.cc        | 44 +-----------------
 tests/st/ops/cpu/test_mul_op.py               |  9 ++++
 tests/st/ops/cpu/test_tensoradd.py            |  9 ++++
 7 files changed, 69 insertions(+), 87 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
index c44a9ae002..1030a6163d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
@@ -181,7 +181,7 @@ void ArithmeticCPUKernel::LaunchLess(const std::vector<AddressPtr> &inputs, cons
   T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
   bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
 
-  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
+  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(bool)) : 1;
   auto max_thread_num = std::thread::hardware_concurrency();
   size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
   MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc
index 0771f82d45..e2728d1702 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc
@@ -67,6 +67,52 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
   }
 }
 
+bool MKLCPUKernel::BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape,
+                                   std::vector<size_t> *dst_shape) {
+  MS_EXCEPTION_IF_NULL(src0_shape);
+  MS_EXCEPTION_IF_NULL(src1_shape);
+  MS_EXCEPTION_IF_NULL(dst_shape);
+  bool need_swap = false;
+  if (dst_shape->size() == 0) {
+    dst_shape->emplace_back(1);
+    src0_shape->emplace_back(1);
+    src1_shape->emplace_back(1);
+  }
+  MS_LOG(DEBUG) << "Binary broadcast in: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape;
+  if (src0_shape->size() != dst_shape->size()) {
+    need_swap = true;
+    for (size_t i = src0_shape->size(); i < dst_shape->size(); ++i) {
+      src0_shape->insert(src0_shape->begin(), 1);
+    }
+  } else if (src1_shape->size() != dst_shape->size()) {
+    for (size_t i = src1_shape->size(); i < dst_shape->size(); ++i) {
+      src1_shape->insert(src1_shape->begin(), 1);
+    }
+  }
+  if (src0_shape->size() == src1_shape->size()) {
+    bool visit_src0 = false;
+    bool visit_src1 = false;
+    for (size_t i = 0; i < src0_shape->size(); ++i) {
+      if (src0_shape->at(i) != src1_shape->at(i)) {
+        if (src0_shape->at(i) == 1 && !visit_src1) {
+          need_swap = true;
+          visit_src0 = true;
+        } else if (src1_shape->at(i) == 1 && !visit_src0) {
+          need_swap = false;
+          visit_src1 = true;
+        } else {
+          MS_LOG(EXCEPTION) << "Invalid broadcast! " << *src0_shape << " vs " << *src1_shape;
+        }
+      }
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "Invalid broadcast! src0: " << *src0_shape << " src1: " << *src1_shape
+                      << " dst: " << *dst_shape;
+  }
+  MS_LOG(DEBUG) << "Binary broadcast out: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape;
+  return need_swap;
+}
+
 dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const {
   dnnl::memory::format_tag mem_tag;
   auto dim_size = dims.size();
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h
index 7f145c7116..a92bc22885 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h
@@ -32,6 +32,8 @@ class MKLCPUKernel : public CPUKernel {
   ~MKLCPUKernel() override = default;
 
  protected:
+  bool BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape,
+                       std::vector<size_t> *dst_shape);
   void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector<size_t> &src_shape,
                   const std::vector<size_t> &kernel_size, int stride, std::vector<int> *padding_l,
                   std::vector<int> *padding_r);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc
index 243a003ecf..a7855f024c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc
@@ -25,49 +25,7 @@ void MulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
   std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
-  if (dst_shape.size() == 0) {
-    dst_shape.emplace_back(1);
-    src0_shape.emplace_back(1);
-    src1_shape.emplace_back(1);
-  }
-  size_t src0_length = 1;
-  size_t src1_length = 1;
-  for (size_t i = 0; i < src0_shape.size(); ++i) {
-    src0_length = src0_length * src0_shape[i];
-  }
-  for (size_t i = 0; i < src1_shape.size(); ++i) {
-    src1_length = src1_length * src1_shape[i];
-  }
-  if (src1_shape.size() != src0_shape.size()) {
-    if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
-      need_swap_ = true;
-      for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
-        src0_shape.emplace_back(1);
-      }
-    } else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
-      for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
-        src1_shape.emplace_back(1);
-      }
-    } else {
-      MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
-    }
-  } else {
-    bool visit_src0 = false;
-    bool visit_src1 = false;
-    for (size_t i = 0; i < src0_shape.size(); ++i) {
-      if (src0_shape[i] != src1_shape[i]) {
-        if (src0_shape[i] == 1 && !visit_src1) {
-          need_swap_ = true;
-          visit_src0 = true;
-        } else if (src1_shape[i] == 1 && !visit_src0) {
-          need_swap_ = false;
-          visit_src1 = true;
-        } else {
-          MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
-        }
-      }
-    }
-  }
+  need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape);
   dnnl::memory::desc src0_desc;
   dnnl::memory::desc src1_desc;
   if (need_swap_) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.cc
index fc09ebbbd9..0e6330192e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.cc
@@ -25,49 +25,7 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
   std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
-  if (dst_shape.size() == 0) {
-    dst_shape.emplace_back(1);
-    src0_shape.emplace_back(1);
-    src1_shape.emplace_back(1);
-  }
-  size_t src0_length = 1;
-  size_t src1_length = 1;
-  for (size_t i = 0; i < src0_shape.size(); ++i) {
-    src0_length = src0_length * src0_shape[i];
-  }
-  for (size_t i = 0; i < src1_shape.size(); ++i) {
-    src1_length = src1_length * src1_shape[i];
-  }
-  if (src1_shape.size() != src0_shape.size()) {
-    if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
-      need_swap_ = true;
-      for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
-        src0_shape.emplace_back(1);
-      }
-    } else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
-      for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
-        src1_shape.emplace_back(1);
-      }
-    } else {
-      MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
-    }
-  } else {
-    bool visit_src0 = false;
-    bool visit_src1 = false;
-    for (size_t i = 0; i < src0_shape.size(); ++i) {
-      if (src0_shape[i] != src1_shape[i]) {
-        if (src0_shape[i] == 1 && !visit_src1) {
-          need_swap_ = true;
-          visit_src0 = true;
-        } else if (src1_shape[i] == 1 && !visit_src0) {
-          need_swap_ = false;
-          visit_src1 = true;
-        } else {
-          MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
-        }
-      }
-    }
-  }
+  need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape);
   dnnl::memory::desc src0_desc;
   dnnl::memory::desc src1_desc;
   if (need_swap_) {
diff --git a/tests/st/ops/cpu/test_mul_op.py b/tests/st/ops/cpu/test_mul_op.py
index c37aae6888..91e9a0b5fc 100644
--- a/tests/st/ops/cpu/test_mul_op.py
+++ b/tests/st/ops/cpu/test_mul_op.py
@@ -47,6 +47,8 @@ def test_mul():
     y2 = Tensor(2, mstype.float32)
     x3 = Tensor(2, mstype.float32)
     y3 = Tensor(2, mstype.float32)
+    x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.float32))
+    y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.float32))
     mul = Net()
     out = mul(x0, y0).asnumpy()
     exp = x0.asnumpy() * y0.asnumpy()
@@ -75,3 +77,10 @@ def test_mul():
     err = np.ones(shape=exp.shape) * 1.0e-5
     assert np.all(diff < err)
     assert out.shape == exp.shape
+
+    out = mul(x4, y4).asnumpy()
+    exp = x4.asnumpy() * y4.asnumpy()
+    diff = np.abs(out - exp)
+    err = np.ones(shape=exp.shape) * 1.0e-5
+    assert np.all(diff < err)
+    assert out.shape == exp.shape
diff --git a/tests/st/ops/cpu/test_tensoradd.py b/tests/st/ops/cpu/test_tensoradd.py
index f0373370ed..ee6dfff67b 100644
--- a/tests/st/ops/cpu/test_tensoradd.py
+++ b/tests/st/ops/cpu/test_tensoradd.py
@@ -45,6 +45,8 @@ def test_tensor_add():
     y2 = Tensor(2, mstype.float32)
     x3 = Tensor(2, mstype.float32)
     y3 = Tensor(2, mstype.float32)
+    x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.float32))
+    y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.float32))
     add = TensorAdd()
     out = add(x0, y0).asnumpy()
     exp = x0.asnumpy() + y0.asnumpy()
@@ -73,3 +75,10 @@ def test_tensor_add():
     err = np.ones(shape=exp.shape) * 1.0e-5
     assert np.all(diff < err)
     assert out.shape == exp.shape
+
+    out = add(x4, y4).asnumpy()
+    exp = x4.asnumpy() + y4.asnumpy()
+    diff = np.abs(out - exp)
+    err = np.ones(shape=exp.shape) * 1.0e-5
+    assert np.all(diff < err)
+    assert out.shape == exp.shape