diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc index fa5704057b..2eac461911 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc @@ -20,39 +20,39 @@ namespace mindspore { namespace kernel { -namespace { template -void Add(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) { +void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { - out[i] = input1[i] + (is_number ? *input2 : input2[i]); + out[i] = input1[i] + input2[i]; } } template -void Sub(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) { +void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { - out[i] = input1[i] - (is_number ? *input2 : input2[i]); + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] - input2[idx[1]]; } } template -void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) { +void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { - out[i] = input1[i] * (is_number ? *input2 : input2[i]); + out[i] = input1[i] * input2[i]; } } template -void Div(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) { +void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { - auto div_number = is_number ? *input2 : input2[i]; + auto div_number = input2[i]; if (div_number == 0) { MS_LOG(EXCEPTION) << "Cannot divided by 0!"; } out[i] = input1[i] / div_number; } } -} // namespace void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); @@ -67,21 +67,20 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { operate_type_ = DIV; } - auto shape0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - if (shape1.size() == 0) { - is_number_ = true; - } else { - is_number_ = false; - if (shape0.size() != shape1.size()) { - MS_LOG(EXCEPTION) << "Input0 and input1 must has the same shape"; - } - for (size_t i = 0; i < shape0.size(); ++i) { - if (shape0[i] != shape1[i]) { - MS_LOG(EXCEPTION) << "Input0 and input1 must has the same shape"; - } - } + input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_shape1_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + size_t l = input_shape0_.size(); + for (size_t i = 0; i < output_shape_.size() - l; ++i) { + input_shape0_.insert(input_shape0_.begin(), 1); + } + l = input_shape1_.size(); + for (size_t i = 0; i < output_shape_.size() - l; ++i) { + input_shape1_.insert(input_shape1_.begin(), 1); } + CPUKernelUtils::GetElementNumEveryDim(input_shape0_, &input_element_num0_); + CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_); + CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) { MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type"; @@ -103,14 +102,43 @@ bool ArithmeticCPUKernel::Launch(const std::vector &inputs, return true; } +void ArithmeticCPUKernel::GenIndex(size_t num, std::vector *idx) { + std::vector tmp; + for (size_t i = 0; i < output_shape_.size() - 1; ++i) { + if (output_element_num_[i] > num) { + tmp.push_back(0); + } else { + tmp.push_back(num / output_element_num_[i]); + num %= output_element_num_[i]; + } + } + tmp.push_back(num); + size_t idx0 = 0; + size_t idx1 = 0; + for (size_t k = 0; k < tmp.size() - 1; ++k) { + if (input_shape0_[k] > 1) { + idx0 += tmp[k] * input_element_num0_[k]; + } + if (input_shape1_[k] > 1) { + idx1 += tmp[k] * input_element_num1_[k]; + } + } + if (input_shape0_[tmp.size() - 1] > 1) { + idx0 += tmp[tmp.size() - 1]; + } + if (input_shape1_[tmp.size() - 1] > 1) { + idx1 += tmp[tmp.size() - 1]; + } + idx->push_back(idx0); + idx->push_back(idx1); +} template void ArithmeticCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { T *input1 = reinterpret_cast(inputs[0]->addr); T *input2 = reinterpret_cast(inputs[1]->addr); T *output = reinterpret_cast(outputs[0]->addr); - auto lens = inputs[0]->size / sizeof(T); + auto lens = outputs[0]->size / sizeof(T); MS_LOG(INFO) << "lens=" << lens; - const size_t thread_num = 24; std::vector threads; threads.reserve(thread_num); @@ -119,13 +147,13 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector &inputs, co while (start < lens) { size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); if (operate_type_ == ADD) { - threads.emplace_back(std::thread(Add, input1, input2, output, start, end, is_number_)); + threads.emplace_back(std::thread(&ArithmeticCPUKernel::Add, this, input1, input2, output, start, end)); } else if (operate_type_ == SUB) { - threads.emplace_back(std::thread(Sub, input1, input2, output, start, end, is_number_)); + threads.emplace_back(std::thread(&ArithmeticCPUKernel::Sub, this, input1, input2, output, start, end)); } else if (operate_type_ == MUL) { - threads.emplace_back(std::thread(Mul, input1, input2, output, start, end, is_number_)); + threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mul, this, input1, input2, output, start, end)); } else if (operate_type_ == DIV) { - threads.emplace_back(std::thread(Div, input1, input2, output, start, end, is_number_)); + threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div, this, input1, input2, output, start, end)); } start += once_compute_size; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h index 20ea77e350..07d984528b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h @@ -36,7 +36,21 @@ class ArithmeticCPUKernel : public CPUKernel { void LaunchKernel(const std::vector &inputs, const std::vector &outputs); private: - bool is_number_{false}; + void GenIndex(size_t num, std::vector *tmp); + template + void Sub(const T *input1, const T *input2, T *out, size_t start, size_t end); + template + void Add(const T *input1, const T *input2, T *out, size_t start, size_t end); + template + void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end); + template + void Div(const T *input1, const T *input2, T *out, size_t start, size_t end); + std::vector input_shape0_; + std::vector input_shape1_; + std::vector input_element_num0_; + std::vector input_element_num1_; + std::vector output_shape_; + std::vector output_element_num_; OperateType operate_type_{ADD}; TypeId dtype_{kTypeUnknown}; }; diff --git a/tests/st/ops/cpu/test_arithmetic_op.py b/tests/st/ops/cpu/test_arithmetic_op.py index d3b77843b2..069a40653c 100644 --- a/tests/st/ops/cpu/test_arithmetic_op.py +++ b/tests/st/ops/cpu/test_arithmetic_op.py @@ -37,10 +37,10 @@ class SubNet(nn.Cell): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_sub(): - x = np.ones([2, 3, 4, 4]).astype(np.int32) - y = 1 + x = np.random.rand(2, 3, 4, 4).astype(np.float32) + y = np.random.rand(4, 1).astype(np.float32) net = SubNet() - output = net(Tensor(x), Tensor(y, mindspore.int32)) - expect_output = np.zeros([2, 3, 4, 4]).astype(np.int) - print(output) + output = net(Tensor(x), Tensor(y, mindspore.float32)) + expect_output = x - y assert np.all(output.asnumpy() == expect_output) +test_sub()