From b8abcf858a44a61c978930e9ee6328e1a3ce8dee Mon Sep 17 00:00:00 2001
From: peixu_ren
Date: Thu, 29 Oct 2020 10:41:32 -0400
Subject: [PATCH] Add expm1 op at GPU back-end

---
 .../gpu/cuda_impl/unary_op_impl.cu            | 32 ++++++++---
 .../gpu/cuda_impl/unary_op_impl.cuh           |  2 +
 .../gpu/math/unary_op_gpu_kernel.cc           | 12 ++--
 .../gpu/math/unary_op_gpu_kernel.h            |  6 ++
 tests/st/ops/gpu/test_expm1_op.py             | 56 +++++++++++++++++++
 5 files changed, 95 insertions(+), 13 deletions(-)
 create mode 100644 tests/st/ops/gpu/test_expm1_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
index 4d40cc5358..890124d6ae 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
@@ -30,6 +30,13 @@ __global__ void ExponentialKernel(const half *input, half *output, const size_t
   return;
 }
 template <typename T>
+__global__ void Expm1Kernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = static_cast<T>(expm1f(static_cast<float>(input[i])));
+  }
+  return;
+}
+template <typename T>
 __global__ void LogarithmKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = logf(input[i]);
@@ -46,21 +53,21 @@ __global__ void LogarithmKernel(const half *input, half *output, const size_t co
 template <typename T>
 __global__ void Log1pKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(log1p(static_cast<float>(input[i])));
+    output[i] = static_cast<T>(log1pf(static_cast<float>(input[i])));
   }
   return;
 }
 template <typename T>
 __global__ void ErfKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(erf(static_cast<float>(input[i])));
+    output[i] = static_cast<T>(erff(static_cast<float>(input[i])));
   }
   return;
 }
 template <typename T>
 __global__ void ErfcKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(erfc(static_cast<float>(input[i])));
+    output[i] = static_cast<T>(erfcf(static_cast<float>(input[i])));
   }
   return;
 }
@@ -204,13 +211,13 @@ void Exponential(const T *input, T *output, const size_t count, cudaStream_t cud
   return;
 }
 template <typename T>
-void Logarithm(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
-  LogarithmKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+void Expm1(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  Expm1Kernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
 template <typename T>
-void Negative(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
-  NegativeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+void Logarithm(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  LogarithmKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
 template <typename T>
@@ -229,6 +236,11 @@ void Erfc(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
   return;
 }
 template <typename T>
+void Negative(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  NegativeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+template <typename T>
 void Reciprocal(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
   ReciprocalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
@@ -290,11 +302,12 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
   return;
 }
 template void Exponential<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Expm1<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Logarithm<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
-template void Negative<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Log1p<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Erf<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Erfc<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Negative<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Reciprocal<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Square<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Sqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
@@ -307,11 +320,12 @@ template void Zeroslike<float>(float *output, const size_t count, cudaStream_t c
 template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Expm1<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Logarithm<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
-template void Negative<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Log1p<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Erf<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Erfc<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Negative<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Reciprocal<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Square<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Sqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
index fadcda56d2..1bb94e07e2 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
@@ -21,6 +21,8 @@
 template <typename T>
 void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
+void Expm1(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
 void Logarithm(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void Log1p(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
index 009b20dc9b..7def6ab7f0 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
@@ -22,13 +22,13 @@ MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_GPU_KERNEL_ONE(Expm1, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+MS_REG_GPU_KERNEL_ONE(Expm1, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(Log1p, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
@@ -42,6 +42,10 @@ MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      UnaryOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
index 784ed8ad75..591768a1b9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
@@ -29,6 +29,7 @@ namespace mindspore {
 namespace kernel {
 enum UnaryOptype {
   UNARY_OP_EXP = 0,
+  UNARY_OP_EXPM1,
   UNARY_OP_LOG,
   UNARY_OP_LOG1P,
   UNARY_OP_ERF,
@@ -48,6 +49,7 @@ enum UnaryOptype {
   UNARY_OP_INVALID_TYPE = 255
 };
 static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP},
+                                                                   {"Expm1", UNARY_OP_EXPM1},
                                                                    {"Log", UNARY_OP_LOG},
                                                                    {"Log1p", UNARY_OP_LOG1P},
                                                                    {"Erf", UNARY_OP_ERF},
@@ -90,6 +92,10 @@ class UnaryOpGpuKernel : public GpuKernel {
         Exponential(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_EXPM1: {
+        Expm1(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       case UNARY_OP_LOG: {
         Logarithm(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
diff --git a/tests/st/ops/gpu/test_expm1_op.py b/tests/st/ops/gpu/test_expm1_op.py
new file mode 100644
index 0000000000..c072563ba0
--- /dev/null
+++ b/tests/st/ops/gpu/test_expm1_op.py
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore import dtype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+class NetExpm1(nn.Cell):
+    def __init__(self):
+        super(NetExpm1, self).__init__()
+        self.expm1 = P.Expm1()
+
+    def construct(self, x):
+        return self.expm1(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_expm1_fp32():
+    expm1 = NetExpm1()
+    x = np.random.rand(3, 8).astype(np.float32)
+    output = expm1(Tensor(x, dtype=dtype.float32))
+    expect = np.expm1(x)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect) < tol).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_expm1_fp16():
+    expm1 = NetExpm1()
+    x = np.random.rand(3, 8).astype(np.float16)
+    output = expm1(Tensor(x, dtype=dtype.float16))
+    expect = np.expm1(x)
+    tol = 1e-3
+    assert (np.abs(output.asnumpy() - expect) < tol).all()
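
Usage note: once this patch is applied, the new GPU kernel is reached through the existing
ops.Expm1 primitive, exactly as the system test above does. Below is a minimal standalone
sketch, not part of the patch, assuming a GPU build of MindSpore that includes this change;
per the registrations above, only float32 and float16 inputs are supported. The sample
values are illustrative only.

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    class NetExpm1(nn.Cell):
        def __init__(self):
            super(NetExpm1, self).__init__()
            # Dispatches to the Expm1 GPU kernel registered in unary_op_gpu_kernel.cc.
            self.expm1 = P.Expm1()

        def construct(self, x):
            return self.expm1(x)

    # Expm1 computes exp(x) - 1 element-wise; compare against NumPy's expm1.
    x = np.array([-1.0, 0.0, 1e-4, 1.0], dtype=np.float32)
    out = NetExpm1()(Tensor(x))
    print(out.asnumpy())   # expected to be close to np.expm1(x)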