diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc new file mode 100644 index 0000000000..05b1a79924 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h" +#include +#include +#include +#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" +#include "device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void SoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + CPUKernel::InitInputOutputSize(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_node); + size_t type_size = sizeof(float); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + workspace_size_list_.emplace_back(tensor_size); +} + +void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + dnnl::memory::dims mem_dims; + mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); + if (mem_dims.size() != 2) { + MS_LOG(EXCEPTION) << "SoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size(); + } + batch_size_ = shape[0]; + class_num_ = shape[1]; + if (batch_size_ == 0 || class_num_ == 0) { + MS_LOG(EXCEPTION) << "invalid batch size or class num input!"; + } + dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); + + dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); + auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + + AddArgument(DNNL_ARG_SRC, mem_desc); + AddArgument(DNNL_ARG_DST, mem_desc); +} + +void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *logits, const float *labels, + float *output1, float *output2) const { + float epsilon = 1e-6; + for (size_t i = 0; i < batch_size_; ++i) { + output1[i] = 0; + float loss = 0.0; + for (size_t j = 0; j < class_num_; ++j) { + float logit = logf(logits[i * class_num_ + j] <= 0.0 ? epsilon : logits[i * class_num_ + j]); + output2[i * class_num_ + j] = logits[i * class_num_ + j] - labels[i * class_num_ + j]; + loss += labels[i * class_num_ + j] * logit; + } + output1[i] = -loss; + } +} + +bool SoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + if (inputs.empty() || workspace.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + size_t batch_float_size = batch_size_ * sizeof(float); + size_t batch_class_float_size = class_num_ * batch_float_size; + if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size || + inputs[1]->size != batch_class_float_size) { + MS_LOG(EXCEPTION) << "error input data size!"; + } + if (outputs[1]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { + MS_LOG(EXCEPTION) << "error output data size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr); + ExecutePrimitive(); + auto labels = reinterpret_cast(inputs[1]->addr); + auto logits = reinterpret_cast(workspace[0]->addr); + auto output1 = reinterpret_cast(outputs[0]->addr); + auto output2 = reinterpret_cast(outputs[1]->addr); + ForwardPostExecute(logits, labels, output1, output2); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h new file mode 100644 index 0000000000..f663508059 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ + +#include +#include +#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class SoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { + public: + SoftmaxCrossEntropyWithLogitsCPUKernel() = default; + ~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + void InitInputOutputSize(const CNodePtr &kernel_node) override; + + private: + void ForwardPostExecute(const float *logits, const float *labels, float *output1, float *output2) const; + size_t class_num_{0}; + size_t batch_size_{0}; +}; +MS_REG_CPU_KERNEL(SoftmaxCrossEntropyWithLogits, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SoftmaxCrossEntropyWithLogitsCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.cc new file mode 100644 index 0000000000..b12371c933 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "kernel/cpu/reduce_cpu_kernel.h" +#include "device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +const size_t kReduceTypeMax = 0; +const size_t kReduceTypeMean = 1; +const size_t kReduceTypeSum = 2; +const size_t kMaxDim = 100; +void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == "ReduceMax") { + reduce_type_ = kReduceTypeMax; + } else if (kernel_name == "ReduceMean") { + reduce_type_ = kReduceTypeMean; + } else if (kernel_name == "ReduceSum") { + reduce_type_ = kReduceTypeSum; + } else { + MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; + } + shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); + if (axis_addr->isa()) { + auto attr_axis = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); + if (attr_axis.size() > shape_.size()) { + MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size(); + } else if (attr_axis.empty()) { + axis_.push_back(shape_.size() - 1); + } else { + for (auto axis : attr_axis) { + if (IntToSize(axis) >= (shape_.size())) { + MS_LOG(EXCEPTION) << "axis value is oversize."; + } + axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); + } + } + } else if (axis_addr->isa()) { + int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + + if (axis >= 0 && IntToSize(axis) >= shape_.size()) { + MS_LOG(EXCEPTION) << "axis value is oversize."; + } + axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; + } + for (size_t i = 0; i < shape_.size(); ++i) { + if (shape_[i] <= 0) { + MS_LOG(EXCEPTION) << "shape value is invalid."; + } + left_dims_ *= shape_[i]; + } + for (size_t i = 0; i < axis_.size(); ++i) { + stride_ *= shape_[axis_[i]]; + } + if (stride_ <= 0) { + MS_LOG(EXCEPTION) << "stride_ must greater than zero."; + } + left_dims_ = left_dims_ / stride_; +} +bool ReduceCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspaces*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output empty!"; + } + size_t out_float_size = left_dims_ * sizeof(float); + size_t in_float_size = stride_ * out_float_size; + if (inputs[0]->size != in_float_size || outputs[0]->size != out_float_size) { + MS_LOG(EXCEPTION) << "invalid input or output data size!"; + } + auto input = reinterpret_cast(inputs[0]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + int size = inputs[0]->size / sizeof(float); + std::vector new_input(IntToSize(size), 0.0); + std::vector transpose_axis; + for (size_t i = 0; i < shape_.size(); ++i) { + bool insert = true; + for (size_t j = 0; j < axis_.size(); ++j) { + if (axis_[j] == i) { + insert = false; + break; + } + } + if (insert) { + transpose_axis.push_back(i); + } + } + (void)transpose_axis.insert(transpose_axis.end(), axis_.begin(), axis_.end()); + Transpose(size, input, shape_, transpose_axis, SizeToInt(shape_.size()), &new_input[0]); + if (reduce_type_ == kReduceTypeMax) { + for (size_t i = 0; i < left_dims_; ++i) { + float value = new_input[i * stride_]; + for (size_t k = 0; k < stride_; ++k) { + if (value < new_input[i * stride_ + k]) { + value = new_input[i * stride_ + k]; + } + } + output[i] = value; + } + } else { + for (size_t i = 0; i < left_dims_; ++i) { + float value = 0.0; + for (size_t k = 0; k < stride_; ++k) { + value += new_input[i * stride_ + k]; + } + if (reduce_type_ == kReduceTypeMean) { + output[i] = value / stride_; + } else { + output[i] = value; + } + } + } + return true; +} +void ReduceCPUKernel::Transpose(const int size, const float *input, const std::vector &input_shape, + const std::vector &input_axis, const int shape_size, float *output) { + int pos_array[kMaxDim]; + int size_offset[kMaxDim]; + size_offset[0] = size / SizeToInt(input_shape[0]); + for (int i = 1; i < shape_size; i++) { + size_offset[i] = size_offset[i - 1] / SizeToInt(input_shape[i]); + } + for (int position = 0; position < size; position += 1) { + int temp_position = position; + pos_array[0] = temp_position / size_offset[0]; + for (int i = 1; i < shape_size; i++) { + temp_position -= pos_array[i - 1] * size_offset[i - 1]; + pos_array[i] = temp_position / size_offset[i]; + } + int new_position = pos_array[SizeToInt(input_axis[shape_size - 1])]; + int new_position_size = 1; + for (int j = shape_size - 2; j >= 0; j--) { + new_position_size *= SizeToInt(input_shape[SizeToInt(input_axis[j + 1])]); + new_position += pos_array[SizeToInt(input_axis[j])] * new_position_size; + } + output[new_position] = input[position]; + } + return; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.h new file mode 100644 index 0000000000..27d28ba3bd --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ +#include +#include +#include +#include "kernel/cpu/cpu_kernel.h" +#include "kernel/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ReduceCPUKernel : public CPUKernel { + public: + ReduceCPUKernel() = default; + ~ReduceCPUKernel() override = default; + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void Transpose(const int size, const float *input, const std::vector &input_shape, + const std::vector &input_axis, const int shape_size, float *output); + size_t reduce_type_; + std::vector axis_; + std::vector shape_; + size_t left_dims_ = 1; + size_t stride_ = 1; +}; +MS_REG_CPU_KERNEL(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceCPUKernel); +MS_REG_CPU_KERNEL(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceCPUKernel); +MS_REG_CPU_KERNEL(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceCPUKernel); + +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_reduce_op.py b/tests/st/ops/cpu/test_reduce_op.py new file mode 100644 index 0000000000..39b2d8fa14 --- /dev/null +++ b/tests/st/ops/cpu/test_reduce_op.py @@ -0,0 +1,93 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +import numpy as np +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import mindspore.context as context +from mindspore.common.api import ms_function + +context.set_context(device_target="CPU") + + +class NetReduce(nn.Cell): + def __init__(self): + super(NetReduce, self).__init__() + self.axis0 = 0 + self.axis1 = 1 + self.axis2 = -1 + self.axis3 = (0, 1) + self.axis4 = (0, 1, 2) + self.reduce_mean = P.ReduceMean(False) + self.reduce_sum = P.ReduceSum(False) + self.reduce_max = P.ReduceMax(False) + + @ms_function + def construct(self, indice): + return (self.reduce_mean(indice, self.axis0), + self.reduce_mean(indice, self.axis1), + self.reduce_mean(indice, self.axis2), + self.reduce_mean(indice, self.axis3), + self.reduce_mean(indice, self.axis4), + self.reduce_sum(indice, self.axis0), + self.reduce_sum(indice, self.axis2), + self.reduce_max(indice, self.axis0), + self.reduce_max(indice, self.axis2)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_reduce(): + reduce = NetReduce() + indice = Tensor(np.array([ + [[0., 2., 1., 4., 0., 2.], [3., 1., 2., 2., 4., 0.]], + [[2., 0., 1., 5., 0., 1.], [1., 0., 0., 4., 4., 3.]], + [[4., 1., 4., 0., 0., 0.], [2., 5., 1., 0., 1., 3.]] + ]).astype(np.float32)) + output = reduce(indice) + print(output[0]) + print(output[1]) + print(output[2]) + print(output[3]) + print(output[4]) + print(output[5]) + print(output[6]) + print(output[7]) + print(output[8]) + expect_0 = np.array([[2., 1., 2., 3., 0., 1], [2., 2., 1., 2., 3., 2.]]).astype(np.float32) + expect_1 = np.array([[1.5, 1.5, 1.5, 3., 2., 1.], [1.5, 0., 0.5, 4.5, 2., 2.], [3., 3., 2.5, 0., 0.5, 1.5]]).astype( + np.float32) + expect_2 = np.array([[1.5, 2.], [1.5, 2.], [1.5, 2.]]).astype(np.float32) + expect_3 = np.array([2, 1.5, 1.5, 2.5, 1.5, 1.5]).astype(np.float32) + expect_4 = np.array([1.75]).astype(np.float32) + expect_5 = np.array([[6., 3., 6., 9., 0., 3.], [6., 6., 3., 6., 9., 6.]]).astype(np.float32) + expect_6 = np.array([[9., 12.], [9., 12.], [9., 12.]]).astype(np.float32) + expect_7 = np.array([[4., 2., 4., 5., 0., 2.], [3., 5., 2., 4., 4., 3.]]).astype(np.float32) + expect_8 = np.array([[4., 4.], [5., 4.], [4., 5.]]).astype(np.float32) + assert (output[0].asnumpy() == expect_0).all() + assert (output[1].asnumpy() == expect_1).all() + assert (output[2].asnumpy() == expect_2).all() + assert (output[3].asnumpy() == expect_3).all() + assert (output[4].asnumpy() == expect_4).all() + assert (output[5].asnumpy() == expect_5).all() + assert (output[6].asnumpy() == expect_6).all() + assert (output[7].asnumpy() == expect_7).all() + assert (output[8].asnumpy() == expect_8).all() + + +test_reduce() diff --git a/tests/st/ops/cpu/test_softmax_cross_entropy_with_logits_op.py b/tests/st/ops/cpu/test_softmax_cross_entropy_with_logits_op.py new file mode 100644 index 0000000000..79689b0b87 --- /dev/null +++ b/tests/st/ops/cpu/test_softmax_cross_entropy_with_logits_op.py @@ -0,0 +1,52 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor + + +class NetSoftmaxCrossEntropyWithLogits(nn.Cell): + def __init__(self): + super(NetSoftmaxCrossEntropyWithLogits, self).__init__() + self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False) + + def construct(self, logits, labels): + return self.loss(logits, labels) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_softmax_cross_entropy_with_logits(): + logits = Tensor(np.array([[1, 1, 10], + [1, 10, 1], + [10, 1, 1]]).astype(np.float32)) + labels = Tensor(np.array([[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]).astype(np.float32)) + expect_loss = [0.00024673, 0.00024673, 0.00024673] + + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + softmax_cross_entropy_with_logits = NetSoftmaxCrossEntropyWithLogits() + output = softmax_cross_entropy_with_logits(logits, labels) + error0 = 1.0e-6 + diff0 = output.asnumpy() - expect_loss + assert np.all(abs(diff0) < error0) + +test_softmax_cross_entropy_with_logits()