Add expm1 op at GPU back-end

pull/8002/head
peixu_ren 4 years ago
parent 852bf44078
commit b8abcf858a

@@ -30,6 +30,13 @@ __global__ void ExponentialKernel(const half *input, half *output, const size_t
   return;
 }
 template <typename T>
+__global__ void Expm1Kernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = static_cast<T>(expm1f(static_cast<float>(input[i])));
+  }
+  return;
+}
+template <typename T>
 __global__ void LogarithmKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = logf(input[i]);
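
Note (not part of the commit): the new kernel routes through expm1f rather than computing expf(x) - 1.0f, which matters numerically. For |x| near zero, expf(x) rounds to 1.0f and the subtraction cancels nearly every significant bit, while expm1f keeps full relative accuracy. A minimal standalone kernel illustrating the difference:

__global__ void Expm1VsNaive(const float *in, float *accurate, float *naive, size_t n) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    accurate[i] = expm1f(in[i]);    // ~x for tiny x, full relative precision
    naive[i] = expf(in[i]) - 1.0f;  // catastrophic cancellation for tiny x
  }
}
// For in[i] = 1e-10f: expm1f returns 1e-10f, but expf(1e-10f) - 1.0f returns 0.0f.
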
@@ -46,21 +53,21 @@ __global__ void LogarithmKernel(const half *input, half *output, const size_t co
 template <typename T>
 __global__ void Log1pKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(log1p(static_cast<double>(input[i])));
+    output[i] = static_cast<T>(log1pf(static_cast<float>(input[i])));
   }
   return;
 }
 template <typename T>
 __global__ void ErfKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(erf(static_cast<float>(input[i])));
+    output[i] = static_cast<T>(erff(static_cast<float>(input[i])));
   }
   return;
 }
 template <typename T>
 __global__ void ErfcKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(erfc(static_cast<float>(input[i])));
+    output[i] = static_cast<T>(erfcf(static_cast<float>(input[i])));
   }
   return;
 }
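
Note (not part of the commit): the three one-character changes above swap the double-precision math functions (log1p, erf, erfc) for their single-precision CUDA counterparts (log1pf, erff, erfcf), so the data stays in float instead of taking a float -> double -> float detour on every element. Since double-precision throughput on most consumer GPUs is a small fraction of single-precision, the f-suffixed calls are also the faster path. A sketch of the two paths side by side:

__global__ void PrecisionPaths(const float *in, float *fast, float *slow, size_t n) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    fast[i] = log1pf(in[i]);                                           // single-precision unit
    slow[i] = static_cast<float>(log1p(static_cast<double>(in[i])));   // double-precision detour
  }
}
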
@@ -204,13 +211,13 @@ void Exponential(const T *input, T *output, const size_t count, cudaStream_t cud
   return;
 }
 template <typename T>
-void Logarithm(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
-  LogarithmKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+void Expm1(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  Expm1Kernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
 template <typename T>
-void Negative(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
-  NegativeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+void Logarithm(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  LogarithmKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
 template <typename T>
@@ -229,6 +236,11 @@ void Erfc(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
   return;
 }
 template <typename T>
+void Negative(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  NegativeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+template <typename T>
 void Reciprocal(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
   ReciprocalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
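
Note (not part of the commit): GET_BLOCKS and GET_THREADS are MindSpore's launch-configuration helpers; their definitions are not shown in this diff. A hypothetical stand-in with the usual ceiling-divide shape, for readers unfamiliar with the pattern:

// Hypothetical sketch only; the real helpers live elsewhere in the repo.
constexpr int kMaxThreadsPerBlock = 256;  // assumed block size

inline int GetThreadsSketch() { return kMaxThreadsPerBlock; }

inline int GetBlocksSketch(size_t total_elements) {
  // ceiling divide: enough blocks to cover every element
  return static_cast<int>((total_elements + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock);
}

Because every kernel in this file uses a grid-stride loop, the code stays correct even if the grid is smaller than the element count; the launch configuration only affects occupancy.
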
@@ -290,11 +302,12 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
 }
 template void Exponential<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Expm1<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Logarithm<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
-template void Negative<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Log1p<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Erf<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Erfc<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Negative<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Reciprocal<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Square<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Sqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
@@ -307,11 +320,12 @@ template void Zeroslike<float>(float *output, const size_t count, cudaStream_t c
 template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Expm1<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Logarithm<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
-template void Negative<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Log1p<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Erf<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Erfc<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Negative<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Reciprocal<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Square<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Sqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);

@@ -21,6 +21,8 @@
 template <typename T>
 void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
+void Expm1(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
 void Logarithm(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void Log1p(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);

@@ -22,13 +22,13 @@ MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_GPU_KERNEL_ONE(Expm1, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+MS_REG_GPU_KERNEL_ONE(Expm1, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(Log1p, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
@@ -42,6 +42,10 @@ MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      UnaryOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
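
Note (not part of the commit): each MS_REG_GPU_KERNEL_ONE line binds an op name plus an input/output dtype signature to an instantiation of UnaryOpGpuKernel, so Expm1 gets one float32 kernel and one float16 kernel. Conceptually (an illustrative sketch, not the macro's real expansion) the registration behaves like a static insert into a factory map:

#include <functional>
#include <map>
#include <string>

struct KernelStub {};  // stand-in for UnaryOpGpuKernel<T>

// key = "OpName:in_dtype->out_dtype", value = factory for the matching kernel
static std::map<std::string, std::function<KernelStub *()>> g_kernel_registry;

struct RegistrarSketch {
  RegistrarSketch(const std::string &key, std::function<KernelStub *()> make) {
    g_kernel_registry[key] = make;  // runs at static-initialization time
  }
};

// What the two Expm1 registrations above amount to, conceptually:
static RegistrarSketch expm1_f32("Expm1:float32->float32", [] { return new KernelStub; });
static RegistrarSketch expm1_f16("Expm1:float16->float16", [] { return new KernelStub; });
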

@@ -29,6 +29,7 @@ namespace mindspore {
 namespace kernel {
 enum UnaryOptype {
   UNARY_OP_EXP = 0,
+  UNARY_OP_EXPM1,
   UNARY_OP_LOG,
   UNARY_OP_LOG1P,
   UNARY_OP_ERF,
@@ -48,6 +49,7 @@ enum UnaryOptype {
   UNARY_OP_INVALID_TYPE = 255
 };
 static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP},
+                                                                   {"Expm1", UNARY_OP_EXPM1},
                                                                    {"Log", UNARY_OP_LOG},
                                                                    {"Log1p", UNARY_OP_LOG1P},
                                                                    {"Erf", UNARY_OP_ERF},
@@ -90,6 +92,10 @@ class UnaryOpGpuKernel : public GpuKernel {
       Exponential(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
       break;
     }
+    case UNARY_OP_EXPM1: {
+      Expm1(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+      break;
+    }
     case UNARY_OP_LOG: {
       Logarithm(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
       break;
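
Note (not part of the commit): inputs[0]->size at the call sites above is a byte count, so each branch divides by sizeof(T) to obtain the element count the CUDA wrappers expect:

#include <cstddef>

// bytes -> element count; e.g. a 3 x 8 float32 tensor: 96 / sizeof(float) == 24
template <typename T>
std::size_t ElementCount(std::size_t size_in_bytes) {
  return size_in_bytes / sizeof(T);
}
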

@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore import dtype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class NetExpm1(nn.Cell):
+    def __init__(self):
+        super(NetExpm1, self).__init__()
+        self.expm1 = P.Expm1()
+
+    def construct(self, x):
+        return self.expm1(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_expm1_fp32():
+    expm1 = NetExpm1()
+    x = np.random.rand(3, 8).astype(np.float32)
+    output = expm1(Tensor(x, dtype=dtype.float32))
+    expect = np.expm1(x)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect) < tol).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_expm1_fp16():
+    expm1 = NetExpm1()
+    x = np.random.rand(3, 8).astype(np.float16)
+    output = expm1(Tensor(x, dtype=dtype.float16))
+    expect = np.expm1(x)
+    tol = 1e-3
+    assert (np.abs(output.asnumpy() - expect) < tol).all()
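
Note (not part of the commit): the tests above exercise the op through the full MindSpore graph pipeline. For kernel-level debugging, a minimal standalone CUDA harness (hypothetical, mirroring the fp32 test's 3 x 8 shape and 1e-6 tolerance) can drive Expm1Kernel directly against the host expm1f reference:

#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
__global__ void Expm1Kernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(expm1f(static_cast<float>(input[i])));
  }
}

int main() {
  const size_t n = 24;  // 3 x 8, as in the Python test
  float h_in[n], h_out[n];
  for (size_t i = 0; i < n; ++i) h_in[i] = static_cast<float>(i) / n;  // values in [0, 1)
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  Expm1Kernel<<<1, 256>>>(d_in, d_out, n);  // grid-stride loop covers all n elements
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (size_t i = 0; i < n; ++i) {
    if (fabsf(h_out[i] - expm1f(h_in[i])) >= 1e-6f) {
      printf("FAIL at %zu\n", i);
      return 1;
    }
  }
  printf("PASS\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
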