add float64 support to SigmoidCrossEntropyWithLogits and Grad

pull/14508/head
TFBunny 4 years ago
parent 5312cb372e
commit 4de6b25d23

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,10 +21,10 @@ __global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const
                                                         const T *dout_addr, T *outputs) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (logits[i] >= 0) {
-      outputs[i] = (1. / (1. + exp(-logits[i])) - labels[i]) * dout_addr[i];
+      outputs[i] = (static_cast<T>(1.) / (static_cast<T>(1.) + exp(-logits[i])) - labels[i]) * dout_addr[i];
     } else {
       const T exp_val = exp(logits[i]);
-      outputs[i] = (exp_val / (1. + exp_val) - labels[i]) * dout_addr[i];
+      outputs[i] = (exp_val / (static_cast<T>(1.) + exp_val) - labels[i]) * dout_addr[i];
     }
   }
 }
@@ -39,3 +39,6 @@ void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const
 template void SigmoidCrossEntropyWithLogitsGrad<float, float>(const size_t size, const float *logits,
                                                               const float *labels, const float *dout_addr,
                                                               float *outputs, cudaStream_t cuda_stream);
+template void SigmoidCrossEntropyWithLogitsGrad<double, double>(const size_t size, const double *logits,
+                                                                const double *labels, const double *dout_addr,
+                                                                double *outputs, cudaStream_t cuda_stream);

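For reference, the branchy expression in the grad kernel above is a numerically stable evaluation of d(loss)/d(logits) = (sigmoid(logits) - labels) * dout: the x >= 0 branch uses 1 / (1 + exp(-x)) and the x < 0 branch uses exp(x) / (1 + exp(x)), so exp() never overflows. A minimal NumPy sketch (illustrative only, not part of this commit; the helper name sigmoid_xent_grad_ref is made up) that mirrors the two branches and reproduces the expected values used in the tests further down:

import numpy as np

def sigmoid_xent_grad_ref(logits, labels, dout):
    # NumPy mirror of SigmoidCrossEntropyWithLogitsGradKernel.
    # Both branches equal sigmoid(x); the result is (sigmoid(x) - labels) * dout.
    x = np.asarray(logits, dtype=np.float64)
    pos = 1.0 / (1.0 + np.exp(-np.maximum(x, 0)))  # used where x >= 0
    e = np.exp(np.minimum(x, 0))
    neg = e / (1.0 + e)                            # used where x < 0
    sig = np.where(x >= 0, pos, neg)
    return (sig - labels) * dout

logits = np.array([[1., 1., 2.], [1., 2., 1.], [2., 1., 1.]])
labels = np.array([[0., 0., 1.], [0., 1., 0.], [1., 0., 0.]])
dout = np.ones((3, 3))
print(sigmoid_xent_grad_ref(logits, labels, dout))
# ~[[0.731059, 0.731059, -0.119203], ...] as in the float32/float64 grad tests
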
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,7 +20,8 @@ template <typename T, typename S>
 __global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     const T reverse_factor = static_cast<T>(logits[i] >= 0);
-    outputs[i] = log1p(exp(logits[i] - 2 * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor);
+    outputs[i] =
+      log1p(exp(logits[i] - static_cast<T>(2) * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor);
   }
 }
@@ -32,3 +33,6 @@ void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *
 template void SigmoidCrossEntropyWithLogits<float, float>(const size_t size, const float *logits, const float *labels,
                                                           float *outputs, cudaStream_t cuda_stream);
+template void SigmoidCrossEntropyWithLogits<double, double>(const size_t size, const double *logits,
+                                                            const double *labels, double *outputs,
+                                                            cudaStream_t cuda_stream);

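The forward kernel's reverse_factor trick is the usual numerically stable form of sigmoid cross entropy: since reverse_factor is 1 where x >= 0 and 0 otherwise, x - 2 * reverse_factor * x equals -|x|, and the whole expression reduces to max(x, 0) - x * z + log1p(exp(-|x|)). A small NumPy check (illustrative only, not part of this commit; sigmoid_xent_ref is a made-up name):

import numpy as np

def sigmoid_xent_ref(logits, labels):
    # NumPy mirror of SigmoidCrossEntropyWithLogitsKernel:
    # equivalent to max(x, 0) - x * z + log1p(exp(-|x|)).
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(labels, dtype=np.float64)
    reverse_factor = (x >= 0).astype(x.dtype)
    return np.log1p(np.exp(x - 2 * reverse_factor * x)) - x * (z - reverse_factor)

logits = np.array([[1., 1., 2.], [1., 2., 1.], [2., 1., 1.]])
labels = np.array([[0., 0., 1.], [0., 1., 0.], [1., 0., 0.]])
print(sigmoid_xent_ref(logits, labels))
# ~[[1.313262, 1.313262, 0.126928], ...] matching the forward-op tests below
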
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_TWO(
   SigmoidCrossEntropyWithLogits,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   SigmoidCrossEntropyWithLogitsGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(
+  SigmoidCrossEntropyWithLogits,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  SigmoidCrossEntropyWithLogitsGpuKernel, double, double)
 }  // namespace kernel
 }  // namespace mindspore

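With the Float64 registration above in place, the op should accept double-precision tensors directly. A minimal usage sketch, assuming a GPU build of MindSpore that includes this change (the standard ops.operations API is used; exact numbers are the same ones the tests check):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

loss_op = P.SigmoidCrossEntropyWithLogits()
logits = Tensor(np.array([[1., 2.]]).astype(np.float64))
labels = Tensor(np.array([[0., 1.]]).astype(np.float64))
print(loss_op(logits, labels))  # roughly [[1.313262, 0.126928]], now in float64
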
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,5 +25,12 @@ MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddOutputAttr(kNumberTypeFloat64),
+                      SigmoidCrossEntropyWithLogitsGradGpuKernel, double, double)
 }  // namespace kernel
 }  // namespace mindspore

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,32 +31,43 @@ class NetSigmoidCrossEntropyWithLogits(nn.Cell):
         return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout)


-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_sigmoid_cross_entropy_with_logits():
+def sigmoid_cross_entropy_with_logits_grad(nptype):
     logits = Tensor(np.array([[1, 1, 2],
                               [1, 2, 1],
-                              [2, 1, 1]]).astype(np.float32))
+                              [2, 1, 1]]).astype(nptype))
     labels = Tensor(np.array([[0, 0, 1],
                               [0, 1, 0],
-                              [1, 0, 0]]).astype(np.float32))
-    dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32))
+                              [1, 0, 0]]).astype(nptype))
+    dout = Tensor(np.ones(shape=[3, 3]).astype(nptype))
     expect = np.array([[0.731059, 0.731059, -0.119203],
                        [0.731059, -0.119203, 0.731059],
-                       [-0.119203, 0.731059, 0.731059]]).astype(np.float32)
+                       [-0.119203, 0.731059, 0.731059]]).astype(nptype)
     error = np.ones(shape=[3, 3]) * 1.0e-6

     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
-    output = sigmoid_cross_entropy_with_logits(logits, labels, dout)
+    net = NetSigmoidCrossEntropyWithLogits()
+    output = net(logits, labels, dout)
     diff = output.asnumpy() - expect
     assert np.all(abs(diff) < error)

     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
-    output = sigmoid_cross_entropy_with_logits(logits, labels, dout)
+    net = NetSigmoidCrossEntropyWithLogits()
+    output = net(logits, labels, dout)
     diff = output.asnumpy() - expect
     assert np.all(abs(diff) < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sigmoid_cross_entropy_with_logits_float32():
+    sigmoid_cross_entropy_with_logits_grad(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sigmoid_cross_entropy_with_logits_float64():
+    sigmoid_cross_entropy_with_logits_grad(np.float64)

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,30 +31,41 @@ class NetSigmoidCrossEntropyWithLogits(nn.Cell):
         return self.loss(logits, labels)


-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_sigmoid_cross_entropy_with_logits():
+def sigmoid_cross_entropy_with_logits(nptype):
     logits = Tensor(np.array([[1, 1, 2],
                               [1, 2, 1],
-                              [2, 1, 1]]).astype(np.float32))
+                              [2, 1, 1]]).astype(nptype))
     labels = Tensor(np.array([[0, 0, 1],
                               [0, 1, 0],
-                              [1, 0, 0]]).astype(np.float32))
+                              [1, 0, 0]]).astype(nptype))
     expect_loss = np.array([[1.313262, 1.313262, 0.126928],
                             [1.313262, 0.126928, 1.313262],
-                            [0.126928, 1.313262, 1.313262]]).astype(np.float32)
+                            [0.126928, 1.313262, 1.313262]]).astype(nptype)
     error = np.ones(shape=[3, 3]) * 1.0e-6

     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
-    output = sigmoid_cross_entropy_with_logits(logits, labels)
+    net = NetSigmoidCrossEntropyWithLogits()
+    output = net(logits, labels)
     diff = output.asnumpy() - expect_loss
     assert np.all(abs(diff) < error)

     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
-    output = sigmoid_cross_entropy_with_logits(logits, labels)
+    net = NetSigmoidCrossEntropyWithLogits()
+    output = net(logits, labels)
     diff = output.asnumpy() - expect_loss
     assert np.all(abs(diff) < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sigmoid_cross_entropy_with_logits_float32():
+    sigmoid_cross_entropy_with_logits(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sigmoid_cross_entropy_with_logits_float64():
+    sigmoid_cross_entropy_with_logits(np.float64)
