add float64 to Addn gpu

pull/13635/head
TFBunny 4 years ago
parent 1965ecb9a1
commit b780e5737c

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,9 @@
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  AddNGpuFwdKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   AddNGpuFwdKernel, float)
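
The hunk above registers a float64 specialization of the AddN GPU kernel: KernelAttr binds kNumberTypeFloat64 inputs and outputs to AddNGpuFwdKernel instantiated with double. A minimal sketch of what this enables at the Python level (illustrative only, not part of the commit; assumes a MindSpore build with the GPU backend available):

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
addn = P.AddN()
x = Tensor(np.ones((2, 3), dtype=np.float64))
# Before this commit, float64 inputs had no matching GPU kernel registration;
# with it, the call dispatches to AddNGpuFwdKernel<double>.
out = addn((x, x))
print(out.dtype)  # expected: Float64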

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ADDN_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ADDN_GPU_KERNEL_H_
 #include <memory>
 #include <vector>
@@ -63,11 +63,18 @@ class AddNGpuFwdKernel : public GpuKernel {
     }
     const float alpha = 1;
     const float beta = 0;
+    const double dalpha = static_cast<double>(1.0f);
+    const double dbeta = static_cast<double>(0.0f);
     for (size_t i = 0; i < num_input_; i++) {
       T *input_addr = GetDeviceAddress<T>(inputs, i);
       if (cudnn_data_type_ == CUDNN_DATA_INT32) {
         ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
                      reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else if (cudnn_data_type_ == CUDNN_DATA_DOUBLE) {
+        CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                    cudnnAddTensor(cudnn_handle_, &dalpha, input_descriptor_, input_addr,
+                                                   &(i > 0 ? dalpha : dbeta), input_descriptor_, work_addr),
+                                    "cudnnAddTensor failed");
       } else {
         CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                     cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
@@ -169,4 +176,4 @@ class AddNGpuFwdKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ADDN_GPU_KERNEL_H_
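
In the launch loop above, the new double branch reuses cudnnAddTensor, which computes C = alpha * A + beta * C. On the first input (i == 0) the kernel passes dbeta = 0 as beta, so the workspace is overwritten; on every later input it passes dalpha = 1, so inputs accumulate. The int32 case keeps the ElewiseArith path, presumably because cudnnAddTensor does not cover integer tensors. A NumPy mimic of the loop's semantics (illustrative sketch, not part of the commit):

import numpy as np

def addn_reference(inputs):
    # Per call, cudnnAddTensor does: work = alpha * input + beta * work.
    alpha = 1.0
    work = np.empty_like(inputs[0])
    for i, a in enumerate(inputs):
        beta = 1.0 if i > 0 else 0.0  # mirrors &(i > 0 ? dalpha : dbeta)
        work = alpha * a + beta * work
    return work

xs = [np.arange(6, dtype=np.float64).reshape(2, 3) for _ in range(3)]
assert np.array_equal(addn_reference(xs), xs[0] + xs[1] + xs[2])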

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -55,3 +55,42 @@ def test_net():
                               [96., 99., 102., 105.]]]]
     assert (output.asnumpy() == expect_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_net_float64():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
+    y = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
+    z = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
+    add = Net()
+    output = add(Tensor(x), Tensor(y), Tensor(z))
+    expect_result = np.array([[[[0., 3., 6., 9.],
+                                [12., 15., 18., 21.],
+                                [24., 27., 30., 33.]],
+                               [[36., 39., 42., 45.],
+                                [48., 51., 54., 57.],
+                                [60., 63., 66., 69.]],
+                               [[72., 75., 78., 81.],
+                                [84., 87., 90., 93.],
+                                [96., 99., 102., 105.]]]]).astype(np.float64)
+    assert (output.asnumpy() == expect_result).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
+    y = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
+    z = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
+    add = Net()
+    output = add(Tensor(x), Tensor(y), Tensor(z))
+    expect_result = np.array([[[[0., 3., 6., 9.],
+                                [12., 15., 18., 21.],
+                                [24., 27., 30., 33.]],
+                               [[36., 39., 42., 45.],
+                                [48., 51., 54., 57.],
+                                [60., 63., 66., 69.]],
+                               [[72., 75., 78., 81.],
+                                [84., 87., 90., 93.],
+                                [96., 99., 102., 105.]]]]).astype(np.float64)
+    assert (output.asnumpy() == expect_result).all()
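
Since x, y and z in the new test are identical arange arrays, the hard-coded expectation is simply three times the input. A quick NumPy cross-check of the table (illustrative only):

import numpy as np

x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float64)
expect = 3 * x  # x + y + z with y = z = x
assert expect[0, 2, 2, 3] == 105.0  # matches the last hard-coded entry above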
