pull/7693/head
wanyiming 4 years ago
parent 17764803ef
commit 0b74b0f50b

@ -0,0 +1,71 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/sigmoid_cross_entropy_with_logits_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
void SigmoidCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
std::vector<uint64_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (const uint64_t &d : x_shape) {
tensor_size_ *= d;
}
}
bool SigmoidCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
}
return true;
}
template <typename T>
void SigmoidCrossEntropyWithLogitsCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto logits_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto labels_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
T zero = (T)0.0;
T one = (T)1.0;
T two = (T)2.0;
for (uint64_t i = 0; i < tensor_size_; ++i) {
if (logits_addr[i] >= zero) {
output_addr[i] = log1p(exp(logits_addr[i] - two * logits_addr[i])) - logits_addr[i] * (labels_addr[i] - one);
} else {
output_addr[i] = log1p(exp(logits_addr[i])) - logits_addr[i] * labels_addr[i];
}
}
}
void SigmoidCrossEntropyWithLogitsCPUKernel::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(EXCEPTION) << "SigmoidCrossEntropyWithLogitsCPUKernel needs 2 inputs, but gets " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "SigmoidCrossEntropyWithLogitsCPUKernel expects 1 output, but gets" << output_num;
}
}
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class SigmoidCrossEntropyWithLogitsCPUKernel : public CPUKernel {
public:
SigmoidCrossEntropyWithLogitsCPUKernel() = default;
~SigmoidCrossEntropyWithLogitsCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
void CheckParam(const CNodePtr &kernel_node);
TypeId dtype_{kTypeUnknown};
uint64_t tensor_size_{1};
};
MS_REG_CPU_KERNEL(
SigmoidCrossEntropyWithLogits,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SigmoidCrossEntropyWithLogitsCPUKernel);
MS_REG_CPU_KERNEL(
SigmoidCrossEntropyWithLogits,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SigmoidCrossEntropyWithLogitsCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_

@ -0,0 +1,72 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/sigmoid_cross_entropy_with_logits_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
void SigmoidCrossEntropyWithLogitsGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
std::vector<uint64_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (const uint64_t &d : x_shape) {
tensor_size_ *= d;
}
}
bool SigmoidCrossEntropyWithLogitsGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
}
return true;
}
template <typename T>
void SigmoidCrossEntropyWithLogitsGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto logits_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto labels_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto dloss_addr = reinterpret_cast<T *>(inputs[2]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
T zero = (T)0.0;
T one = (T)1.0;
for (uint64_t i = 0; i < tensor_size_; ++i) {
if (logits_addr[i] >= zero) {
output_addr[i] = (one / (one + exp(-logits_addr[i])) - labels_addr[i]) * dloss_addr[i];
} else {
const T exp_val = exp(logits_addr[i]);
output_addr[i] = (exp_val / (one + exp_val) - labels_addr[i]) * dloss_addr[i];
}
}
}
void SigmoidCrossEntropyWithLogitsGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "SigmoidCrossEntropyWithLogitsCPUKernel needs 2 inputs, but gets " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "SigmoidCrossEntropyWithLogitsCPUKernel expects 1 output, but gets" << output_num;
}
}
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class SigmoidCrossEntropyWithLogitsGradCPUKernel : public CPUKernel {
public:
SigmoidCrossEntropyWithLogitsGradCPUKernel() = default;
~SigmoidCrossEntropyWithLogitsGradCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
void CheckParam(const CNodePtr &kernel_node);
TypeId dtype_{kTypeUnknown};
uint64_t tensor_size_{1};
};
MS_REG_CPU_KERNEL(SigmoidCrossEntropyWithLogitsGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SigmoidCrossEntropyWithLogitsGradCPUKernel);
MS_REG_CPU_KERNEL(SigmoidCrossEntropyWithLogitsGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SigmoidCrossEntropyWithLogitsGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_CPU_KERNEL_H_

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
class NetSigmoidCrossEntropyWithLogits(nn.Cell):
def __init__(self):
super(NetSigmoidCrossEntropyWithLogits, self).__init__()
self.sigmoid_cross_entropy_with_logits_grad = G.SigmoidCrossEntropyWithLogitsGrad()
def construct(self, logits, labels, dout):
return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sigmoid_cross_entropy_with_logits():
logits = Tensor(np.array([[1, 1, 2],
[1, 2, 1],
[2, 1, 1]]).astype(np.float32))
labels = Tensor(np.array([[0, 0, 1],
[0, 1, 0],
[1, 0, 0]]).astype(np.float32))
dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32))
expect = np.array([[0.731059, 0.731059, -0.119203],
[0.731059, -0.119203, 0.731059],
[-0.119203, 0.731059, 0.731059]]).astype(np.float32)
error = np.ones(shape=[3, 3]) * 1.0e-6
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
output = sigmoid_cross_entropy_with_logits(logits, labels, dout)
diff = output.asnumpy() - expect
assert np.all(abs(diff) < error)

@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class NetSigmoidCrossEntropyWithLogits(nn.Cell):
def __init__(self):
super(NetSigmoidCrossEntropyWithLogits, self).__init__()
self.loss = P.SigmoidCrossEntropyWithLogits()
def construct(self, logits, labels):
return self.loss(logits, labels)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sigmoid_cross_entropy_with_logits():
logits = Tensor(np.array([[1, 1, 2],
[1, 2, 1],
[2, 1, 1]]).astype(np.float32))
labels = Tensor(np.array([[0, 0, 1],
[0, 1, 0],
[1, 0, 0]]).astype(np.float32))
expect_loss = np.array([[1.313262, 1.313262, 0.126928],
[1.313262, 0.126928, 1.313262],
[0.126928, 1.313262, 1.313262]]).astype(np.float32)
error = np.ones(shape=[3, 3]) * 1.0e-6
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
output = sigmoid_cross_entropy_with_logits(logits, labels)
diff = output.asnumpy() - expect_loss
assert np.all(abs(diff) < error)
Loading…
Cancel
Save