diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_cpu_kernel.cc
new file mode 100644
index 0000000000..289ba8fe6f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_cpu_kernel.cc
@@ -0,0 +1,129 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/binary_cross_entropy_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+// Values held in reduction_, one per supported "reduction" attribute string.
+constexpr int kReductionNone = 0;
+constexpr int kReductionMean = 1;
+constexpr int kReductionSum = 2;
+}  // namespace
+
+// Sums the input_size per-element losses held in tmp_loss and writes the result to loss[0],
+// dividing by input_size when reduction selects the mean. tmp_loss is used as scratch and is
+// overwritten. The additions are done pairwise, halving the active range on every pass.
+template <typename T>
+void BinaryCrossEntropyCpuKernel::LaunchToScalar(const int &input_size, const int &reduction, T *loss, T *tmp_loss) {
+  // Fold the unpaired last element of an odd-sized input into slot 0. A single-element input is
+  // skipped: its "last" element is slot 0 itself, and adding it would double the loss.
+  if (input_size > 1 && input_size % 2 == 1) {
+    tmp_loss[0] += tmp_loss[input_size - 1];
+  }
+
+  for (int stride = input_size / 2; stride > 0; stride >>= 1) {
+    for (int i = 0; i < stride; i++) {
+      tmp_loss[i] += tmp_loss[i + stride];
+    }
+    // With an odd stride the next pass only covers slots [0, stride - 1), so fold slot
+    // stride - 1 into slot 0 now. A stride of 1 is already the final pass.
+    if (stride > 2 && stride % 2 == 1) {
+      tmp_loss[0] += tmp_loss[stride - 1];
+    }
+  }
+
+  // Assign rather than accumulate, so the result does not depend on what the output buffer
+  // held before the launch.
+  loss[0] = tmp_loss[0];
+  if (reduction == kReductionMean) {
+    loss[0] /= static_cast<T>(input_size);
+  }
+}
+
+// Computes -weight * (y * log(x + eps) + (1 - y) * log(1 - x + eps)) for every element, with
+// x = inputs[0], y = inputs[1] and weight = inputs[2]. With reduction "none" the per-element
+// values go straight to outputs[0]; otherwise they are reduced to the single value outputs[0][0].
+// workspace is not used.
+template <typename T>
+void BinaryCrossEntropyCpuKernel::Launchkernel(const std::vector<AddressPtr> &inputs,
+                                               const std::vector<AddressPtr> &workspace,
+                                               const std::vector<AddressPtr> &outputs) {
+  T *input_x = reinterpret_cast<T *>(inputs[0]->addr);
+  T *input_y = reinterpret_cast<T *>(inputs[1]->addr);
+  T *weight = reinterpret_cast<T *>(inputs[2]->addr);
+  T *loss = reinterpret_cast<T *>(outputs[0]->addr);
+
+  // Scratch for the per-element losses; only needed when a reduction follows.
+  const bool reduce = (reduction_ != kReductionNone);
+  std::vector<T> tmp_loss(reduce ? input_size_ : 0);
+  T *element_loss = reduce ? tmp_loss.data() : loss;
+
+  // NOTE(review): 1e-12 is below the smallest positive half-precision value, so for float16 this
+  // guard presumably rounds to zero - confirm whether a larger epsilon is wanted for that type.
+  T epsilon = static_cast<T>(1e-12);
+  T one = static_cast<T>(1);
+  for (size_t i = 0; i < input_size_; i++) {
+    element_loss[i] =
+      -weight[i] * (input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon));
+  }
+
+  if (reduce) {
+    LaunchToScalar<T>(static_cast<int>(input_size_), reduction_, loss, tmp_loss.data());
+  }
+}
+
+// Runs the kernel for the input dtype recorded by InitKernel. Always reports success; nothing
+// is computed for an empty input or for a dtype other than float32 / float16.
+bool BinaryCrossEntropyCpuKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                         const std::vector<AddressPtr> &workspace,
+                                         const std::vector<AddressPtr> &outputs) {
+  if (input_size_ > 0) {
+    if (dtype_ == kNumberTypeFloat32) {
+      Launchkernel<float>(inputs, workspace, outputs);
+    } else if (dtype_ == kNumberTypeFloat16) {
+      Launchkernel<float16>(inputs, workspace, outputs);
+    }
+  }
+  return true;
+}
+
+// Reads the element count (from the inferred shape of input 0), the "reduction" attribute and
+// the dtype of input 0 from kernel_node.
+void BinaryCrossEntropyCpuKernel::InitKernel(const CNodePtr &kernel_node) {
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  // Build the product in a local so that a repeated InitKernel call does not multiply into the
+  // element count left by the previous call.
+  size_t element_count = 1;
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    element_count *= input_shape[i];
+  }
+  input_size_ = element_count;
+
+  const std::string reduction = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "reduction");
+  if (reduction == "none") {
+    reduction_ = kReductionNone;
+  } else if (reduction == "sum") {
+    reduction_ = kReductionSum;
+  } else {
+    // Every other value, "mean" included, selects the mean (the constructor default).
+    reduction_ = kReductionMean;
+  }
+
+  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_cpu_kernel.h
new file mode 100644
index 0000000000..b610f2eab1
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_cpu_kernel.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H
+
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+// CPU kernel for the BinaryCrossEntropy operator: a weighted binary cross entropy over
+// float16 or float32 inputs with an optional "mean" or "sum" reduction.
+class BinaryCrossEntropyCpuKernel : public CPUKernel {
+ public:
+  BinaryCrossEntropyCpuKernel() : input_size_(1), reduction_(1) {}
+  ~BinaryCrossEntropyCpuKernel() override = default;
+
+  // Caches the element count, reduction mode and input dtype of kernel_node.
+  void InitKernel(const CNodePtr &kernel_node) override;
+  // Computes the loss of inputs {x, y, weight} into outputs[0]; always returns true.
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  // Reduces the per-element losses in tmp_loss (scratch, overwritten) into loss[0].
+  template <typename T>
+  void LaunchToScalar(const int &input_size, const int &reduction, T *loss, T *tmp_loss);
+  // Typed implementation behind Launch.
+  template <typename T>
+  void Launchkernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                    const std::vector<AddressPtr> &outputs);
+
+  TypeId dtype_{kTypeUnknown};
+  size_t input_size_;  // number of elements in input 0
+  int reduction_;      // 0 = "none", 1 = "mean" (default), 2 = "sum"
+};
+MS_REG_CPU_KERNEL(BinaryCrossEntropy,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat16)
+                    .AddInputAttr(kNumberTypeFloat16)
+                    .AddInputAttr(kNumberTypeFloat16)
+                    .AddOutputAttr(kNumberTypeFloat16),
+                  BinaryCrossEntropyCpuKernel);
+MS_REG_CPU_KERNEL(BinaryCrossEntropy,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  BinaryCrossEntropyCpuKernel);
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_grad_kernel.cc
new file mode 100644
index 0000000000..e793c3d95e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_grad_kernel.cc
@@ -0,0 +1,101 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/binary_cross_entropy_grad_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+// Values held in reduction_, one per supported "reduction" attribute string.
+constexpr int kReductionNone = 0;
+constexpr int kReductionMean = 1;
+constexpr int kReductionSum = 2;
+}  // namespace
+
+// Computes dx = weight * (x - y) / max(x * (1 - x), eps) * dloss for every element, with
+// x = inputs[0], y = inputs[1], dloss = inputs[2] and weight = inputs[3]. With reduction "none"
+// dloss is read per element; otherwise only dloss[0] is used, divided by the element count for
+// the mean. The result is written to outputs[0].
+template <typename T>
+void BinaryCrossEntropyGradCpuKernel::Launchkernel(const std::vector<AddressPtr> &inputs,
+                                                   const std::vector<AddressPtr> &outputs) {
+  T *input_x = reinterpret_cast<T *>(inputs[0]->addr);
+  T *input_y = reinterpret_cast<T *>(inputs[1]->addr);
+  T *dloss = reinterpret_cast<T *>(inputs[2]->addr);
+  T *weight = reinterpret_cast<T *>(inputs[3]->addr);
+  T *dx = reinterpret_cast<T *>(outputs[0]->addr);
+
+  // NOTE(review): 1e-12 is below the smallest positive half-precision value, so for float16 this
+  // floor presumably rounds to zero and no longer guards the division - confirm.
+  T epsilon = static_cast<T>(1e-12);
+  T one = static_cast<T>(1);
+
+  // Upstream gradient shared by all elements when the forward pass reduced to a scalar.
+  const bool per_element = (reduction_ == kReductionNone);
+  T scalar_dloss = dloss[0];
+  if (reduction_ == kReductionMean) {
+    scalar_dloss = dloss[0] / static_cast<T>(input_size_);
+  }
+
+  for (size_t i = 0; i < input_size_; i++) {
+    // Floor the denominator at epsilon so predictions at 0 or 1 do not divide by zero.
+    T product = input_x[i] * (one - input_x[i]);
+    T denominator = (product > epsilon) ? product : epsilon;
+    T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
+    dx[i] = value * (per_element ? dloss[i] : scalar_dloss);
+  }
+}
+
+// Runs the kernel for the input dtype recorded by InitKernel. Always reports success; nothing
+// is computed for an empty input or for a dtype other than float32 / float16.
+bool BinaryCrossEntropyGradCpuKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                             const std::vector<AddressPtr> &workspace,
+                                             const std::vector<AddressPtr> &outputs) {
+  if (input_size_ > 0) {
+    if (dtype_ == kNumberTypeFloat32) {
+      Launchkernel<float>(inputs, outputs);
+    } else if (dtype_ == kNumberTypeFloat16) {
+      Launchkernel<float16>(inputs, outputs);
+    }
+  }
+  return true;
+}
+
+// Reads the element count (from the inferred shape of input 0), the "reduction" attribute and
+// the dtype of input 0 from kernel_node.
+void BinaryCrossEntropyGradCpuKernel::InitKernel(const CNodePtr &kernel_node) {
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  // Build the product in a local so that a repeated InitKernel call does not multiply into the
+  // element count left by the previous call.
+  size_t element_count = 1;
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    element_count *= input_shape[i];
+  }
+  input_size_ = element_count;
+
+  const std::string reduction = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "reduction");
+  if (reduction == "none") {
+    reduction_ = kReductionNone;
+  } else if (reduction == "sum") {
+    reduction_ = kReductionSum;
+  } else {
+    // Every other value, "mean" included, selects the mean (the constructor default).
+    reduction_ = kReductionMean;
+  }
+
+  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_grad_kernel.h
new file mode 100644
index 0000000000..c24c23e4bd
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/binary_cross_entropy_grad_kernel.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H
+
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+// CPU kernel for the BinaryCrossEntropyGrad operator: the gradient of the weighted binary
+// cross entropy with respect to the prediction, for float16 or float32 inputs.
+class BinaryCrossEntropyGradCpuKernel : public CPUKernel {
+ public:
+  BinaryCrossEntropyGradCpuKernel() : input_size_(1), reduction_(1) {}
+  ~BinaryCrossEntropyGradCpuKernel() override = default;
+
+  // Caches the element count, reduction mode and input dtype of kernel_node.
+  void InitKernel(const CNodePtr &kernel_node) override;
+  // Computes the gradient from inputs {x, y, dloss, weight} into outputs[0]; always returns true.
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  // Typed implementation behind Launch.
+  template <typename T>
+  void Launchkernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+  TypeId dtype_{kTypeUnknown};
+  size_t input_size_;  // number of elements in input 0
+  int reduction_;      // 0 = "none", 1 = "mean" (default), 2 = "sum"
+};
+MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat16)
+                    .AddInputAttr(kNumberTypeFloat16)
+                    .AddInputAttr(kNumberTypeFloat16)
+                    .AddInputAttr(kNumberTypeFloat16)
+                    .AddOutputAttr(kNumberTypeFloat16),
+                  BinaryCrossEntropyGradCpuKernel);
+MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  BinaryCrossEntropyGradCpuKernel);
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H
MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index cf19393873..2b57031b3c 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -4482,7 +4482,7 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
         Otherwise, the output is a scalar.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> import mindspore
diff --git a/tests/st/ops/cpu/test_binary_cross_entropy_op.py b/tests/st/ops/cpu/test_binary_cross_entropy_op.py
new file mode 100644
index 0000000000..d0c774588c
--- /dev/null
+++ b/tests/st/ops/cpu/test_binary_cross_entropy_op.py
@@ -0,0 +1,152 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class Net(nn.Cell):
+    """Wraps P.BinaryCrossEntropy with the given reduction mode."""
+
+    def __init__(self, reduction="none"):
+        super(Net, self).__init__()
+        self.BinaryCrossEntropy = P.BinaryCrossEntropy(reduction)
+
+    def construct(self, x, y, weight):
+        return self.BinaryCrossEntropy(x, y, weight)
+
+
+def _run_loss(reduction, dtype):
+    """Runs Net(reduction) on 20 seeded random values per input and returns the loss as numpy."""
+    np.random.seed(42)
+    prediction = np.random.rand(20).astype(dtype)
+    target = np.random.rand(20).astype(dtype)
+    weight = np.random.rand(20).astype(dtype)
+    net = Net(reduction)
+    loss = net(Tensor(prediction), Tensor(target), Tensor(weight))
+    return loss.asnumpy()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_binary_cross_entropy_loss():
+    """float32, reduction "none": one loss value per element."""
+    loss = _run_loss("none", np.float32)
+    expect = [0.09555826, 1.2861121, 0.03518666, 0.6969416, 0.24313456, 0.99062896,
+              0.19205657, 0.5465214, 0.36964455, 0.21999404, 2.2953863, 2.2566645,
+              1.5803775, 1.3266402, 0.9883408, 1.2997618, 0.05439841, 0.14389999,
+              0.03405444, 0.23934692]
+    assert np.allclose(loss, expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_binary_cross_entropy_loss_mean():
+    """float32, reduction "mean"."""
+    loss = _run_loss("mean", np.float32)
+    expect = [0.7447324991226196]
+    assert np.allclose(loss, expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_binary_cross_entropy_loss_sum():
+    """float32, reduction "sum"."""
+    loss = _run_loss("sum", np.float32)
+    expect = [14.894649505615234]
+    assert np.allclose(loss, expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_binary_cross_entropy_loss_16():
+    """float16, reduction "none": one loss value per element."""
+    loss = _run_loss("none", np.float16)
+    expect = [0.09552, 1.28613, 0.0351868, 0.696777, 0.243164, 0.990234,
+              0.192139, 0.546875, 0.370117, 0.219971, 2.29492, 2.25391,
+              1.58105, 1.32812, 0.987305, 1.30078, 0.0544434, 0.143921,
+              0.0340576, 0.239258]
+    assert np.allclose(loss, expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_binary_cross_entropy_loss_mean_16():
+    """float16, reduction "mean"."""
+    loss = _run_loss("mean", np.float16)
+    expect = [0.74462890625]
+    assert np.allclose(loss, expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_binary_cross_entropy_loss_sum_16():
+    """float16, reduction "sum"."""
+    loss = _run_loss("sum", np.float16)
+    expect = [14.890625]
+    assert np.allclose(loss, expect)
+
+
+class Grad(nn.Cell):
+    """Returns the gradients of `network` with respect to all of its inputs."""
+
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, x1, x2, sens, weight):
+        # NOTE(review): with sens_param=True the sensitivity is presumably the LAST argument, so
+        # `sens` here reaches Net.construct as its weight and `weight` acts as the sensitivity.
+        # The expected values below were produced with this order - confirm before renaming.
+        gout = self.grad(self.network)(x1, x2, sens, weight)
+        return gout
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_binary_cross_entropy_loss_grad():
+    """float32, reduction "none": gradient with respect to the prediction."""
+    np.random.seed(42)
+    prediction = np.random.rand(20).astype(np.float32)
+    target = np.random.rand(20).astype(np.float32)
+    sens = np.random.rand(20).astype(np.float32)
+    weight = np.random.rand(20).astype(np.float32)
+    reduction = "none"
+    grad = Grad(Net(reduction))
+    dx = grad(Tensor(prediction), Tensor(target), Tensor(sens), Tensor(weight))
+
+    dx1_expect = [-4.80516590e-02, 2.32625079e+00, 6.38972521e-02, 3.13642323e-01,
+                  -1.65661633e-01, -1.71821892e+00, -1.13685496e-01, 1.26669514e+00,
+                  1.47891801e-03, 5.83921909e-01, -2.17992840e+01, 4.21899414e+00,
+                  2.85430793e-02, -3.21346498e+00, -2.22674108e+00, -2.80453944e+00,
+                  -1.19787852e-04, 2.48514321e-02, -1.66696273e-02, -2.71965731e-02]
+
+    assert np.allclose(dx[0].asnumpy(), dx1_expect)