diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc index 86c7d8c108..badaedf99a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,9 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_ONE( + AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + AddNGpuFwdKernel, double) MS_REG_GPU_KERNEL_ONE( AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), AddNGpuFwdKernel, float) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h index b017c741c5..b692727e3d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ADDN_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ADDN_GPU_KERNEL_H_ #include #include @@ -63,11 +63,18 @@ class AddNGpuFwdKernel : public GpuKernel { } const float alpha = 1; const float beta = 0; + const double dalpha = static_cast(1.0f); + const double dbeta = static_cast(0.0f); for (size_t i = 0; i < num_input_; i++) { T *input_addr = GetDeviceAddress(inputs, i); if (cudnn_data_type_ == CUDNN_DATA_INT32) { ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr, reinterpret_cast(stream_ptr)); + } else if (cudnn_data_type_ == CUDNN_DATA_DOUBLE) { + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, + cudnnAddTensor(cudnn_handle_, &dalpha, input_descriptor_, input_addr, + &(i > 0 ? dalpha : dbeta), input_descriptor_, work_addr), + "cudnnAddTensor failed"); } else { CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr, @@ -169,4 +176,4 @@ class AddNGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ADDN_GPU_KERNEL_H_ diff --git a/tests/st/ops/gpu/test_addn_op.py b/tests/st/ops/gpu/test_addn_op.py index 45926742a9..0fb7846960 100644 --- a/tests/st/ops/gpu/test_addn_op.py +++ b/tests/st/ops/gpu/test_addn_op.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -55,3 +55,37 @@ def test_net():
                        [96., 99., 102., 105.]]]]
 
     assert (output.asnumpy() == expect_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_net_float64():
+    """
+    Check AddN on GPU for float64 inputs in PyNative and Graph mode.
+
+    Three identical (1, 3, 3, 4) tensors holding 0..35 are fed to Net; the
+    result must be float64 and equal three times the input element-wise.
+    """
+    x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
+    y = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
+    z = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
+    expect_result = np.array([[[[0., 3., 6., 9.],
+                                [12., 15., 18., 21.],
+                                [24., 27., 30., 33.]],
+                               [[36., 39., 42., 45.],
+                                [48., 51., 54., 57.],
+                                [60., 63., 66., 69.]],
+                               [[72., 75., 78., 81.],
+                                [84., 87., 90., 93.],
+                                [96., 99., 102., 105.]]]]).astype(np.float64)
+
+    # Both execution modes share the same inputs and expected values.
+    for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
+        context.set_context(mode=mode, device_target="GPU")
+        add = Net()
+        output = add(Tensor(x), Tensor(y), Tensor(z)).asnumpy()
+        # The expected values are small integers that float32 also holds
+        # exactly, so only the dtype check can reveal a silent downcast.
+        assert output.dtype == np.float64
+        assert (output == expect_result).all()