!7548 Support elu and elugrad with dtype float and half on gpu

Merge pull request !7548 from zhouyuanshen/master
pull/7548/MERGE
mindspore-ci-bot authored; committed by Gitee
commit e24b50f559

@@ -33,6 +33,11 @@ MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
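
These registrations make the Elu primitive dispatch to the cuDNN-backed ActivationGpuFwdKernel for both float32 and float16 on GPU. A minimal usage sketch (illustrative values only; PyNative mode so the primitive can be called directly):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
elu = P.Elu()
# Both dtypes below are covered by the new Elu GPU kernel registrations.
print(elu(Tensor(np.array([-1.0, 2.0]).astype(np.float32))))
print(elu(Tensor(np.array([-1.0, 2.0]).astype(np.float16))))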

@@ -84,6 +84,10 @@ class ActivationGpuFwdKernel : public GpuKernel {
    }
    std::vector<size_t> shape;
    double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0;
    if (mode_ == CUDNN_ACTIVATION_ELU) {
      float alpha = GetAttr<float>(kernel_node, "alpha");
      coef = static_cast<double>(alpha);
    }
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef),
                                "cudnnSetActivationDescriptor failed");
@@ -137,7 +141,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
{"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
{"Tanh", CUDNN_ACTIVATION_TANH},
{"ELU", CUDNN_ACTIVATION_ELU},
{"Elu", CUDNN_ACTIVATION_ELU},
{"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
cudnnHandle_t cudnn_handle_;

@@ -45,6 +45,15 @@ MS_REG_GPU_KERNEL_ONE(
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
EluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
EluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
SigmoidGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),

@@ -91,6 +91,7 @@ class ActivationGradGpuKernel : public GpuKernel {
    }
    std::vector<size_t> shape;
    double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0;
    if (mode_ == CUDNN_ACTIVATION_ELU) coef = 1.0;
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef),
                                "SetActivationDescriptor failed");
@@ -143,7 +144,7 @@ class ActivationGradGpuKernel : public GpuKernel {
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU},
{"ReLU6Grad", CUDNN_ACTIVATION_CLIPPED_RELU},
{"TanhGrad", CUDNN_ACTIVATION_TANH},
{"ELUGrad", CUDNN_ACTIVATION_ELU},
{"EluGrad", CUDNN_ACTIVATION_ELU},
{"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}};
cudnnHandle_t cudnn_handle_;
cudnnActivationDescriptor_t activation_desc_;

@@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

class NetEluGrad(nn.Cell):
    def __init__(self):
        super(NetEluGrad, self).__init__()
        self.eluGrad = G.EluGrad()

    def construct(self, x, dy):
        return self.eluGrad(dy, x)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_elu_grad_fp16():
    x = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float16))
    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float16))
    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float16)
    error = np.ones(shape=[2, 3]) * 1.0e-6
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    elu_grad = NetEluGrad()
    output = elu_grad(x, dy)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_elu_grad_fp32():
    x = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float32))
    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float32))
    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float32)
    error = np.ones(shape=[2, 3]) * 1.0e-6
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    elu_grad = NetEluGrad()
    output = elu_grad(x, dy)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
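
For reference, the hard-coded expect arrays in these tests can be reproduced with the rule sketched earlier, treating x as the forward output and alpha = 1.0 (a cross-check, not part of the test):

import numpy as np

x = np.array([[0.5, 2, 5.5], [4.5, -2, 0]])
dy = np.array([[2, 1, 1.5], [-0.5, -1, -3]])
# dy where x > 0, dy * (x + 1) otherwise
print(np.where(x > 0, dy, dy * (x + 1.0)))
# -> [[ 2.   1.   1.5]
#     [-0.5  1.  -3. ]]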

@@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class NetElu(nn.Cell):
    def __init__(self):
        super(NetElu, self).__init__()
        self.elu = P.Elu()

    def construct(self, x):
        return self.elu(x)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_elu_fp16():
    x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]).astype(np.float16))
    expect = np.array([[-0.632, 4.0, -0.999], [2.0, -0.993, 9.0]]).astype(np.float16)
    error = np.ones(shape=[2, 3]) * 1.0e-6

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    elu = NetElu()
    output = elu(x)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    elu = NetElu()
    output = elu(x)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_elu_fp32():
    x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]).astype(np.float32))
    expect = np.array([[-0.632, 4.0, -0.999], [2.0, -0.993, 9.0]]).astype(np.float32)
    error = np.ones(shape=[2, 3]) * 1.0e-6

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    elu = NetElu()
    output = elu(x)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    elu = NetElu()
    output = elu(x)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
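
Similarly, the expect values in the forward tests are the standard ELU with alpha = 1.0, truncated to three decimals (a cross-check only):

import numpy as np

x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]])
print(np.where(x > 0, x, np.exp(x) - 1.0))
# -> approximately [[-0.6321  4.      -0.9997]
#                   [ 2.     -0.9933  9.    ]]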