From 386b219ffc3b4de9b6a8ed7aadaa248bd90e139b Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Fri, 15 Jan 2021 16:48:58 -0500 Subject: [PATCH] initial commit ci fix fix ci --- .../gpu/arrays/one_hot_gpu_kernel.cc | 17 +++++++++- .../gpu/cuda_impl/one_hot_impl.cu | 6 ++++ mindspore/ops/operations/nn_ops.py | 4 +-- tests/st/ops/gpu/test_one_hot_op.py | 31 ++++++++++++------- 4 files changed, 44 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc index e764a08dc8..11b0dcf678 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include "backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h" @@ -32,5 +33,19 @@ MS_REG_GPU_KERNEL_TWO(OneHot, .AddInputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16), OneHotGpuFwdKernel, half, int) +MS_REG_GPU_KERNEL_TWO(OneHot, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + OneHotGpuFwdKernel, float, int64_t) +MS_REG_GPU_KERNEL_TWO(OneHot, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + OneHotGpuFwdKernel, half, int64_t) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu index b0adf756e1..9447e6e0e6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu @@ -49,3 +49,9 @@ template void OneHot(const int *indices, size_t depth, const float * size_t left_dim_size, size_t right_dim_size, float *output, cudaStream_t cuda_stream); template void OneHot(const int *indices, size_t depth, const half *on_value, const half *off_value, size_t left_dim_size, size_t right_dim_size, half *output, cudaStream_t cuda_stream); +template void OneHot(const int64_t *indices, size_t depth, const float *on_value, + const float *off_value, size_t left_dim_size, size_t right_dim_size, float *output, + cudaStream_t cuda_stream); +template void OneHot(const int64_t *indices, size_t depth, const half *on_value, const half *off_value, + size_t left_dim_size, size_t right_dim_size, half *output, + cudaStream_t cuda_stream); diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 2de71eceb3..c706b8e1b3 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3030,7 +3030,7 @@ class OneHot(PrimitiveWithInfer): Inputs: - **indices** (Tensor) - A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`. - Data type must be int32. + Data type must be int32 or int64. - **depth** (int) - A scalar defining the depth of the one hot dimension. - **on_value** (Tensor) - A value to fill in output when `indices[j] = i`. With data type of float16 or float32. - **off_value** (Tensor) - A value to fill in output when `indices[j] != i`. @@ -3060,7 +3060,7 @@ class OneHot(PrimitiveWithInfer): def __infer__(self, indices, depth, on_value, off_value): # check type - validator.check_tensor_dtype_valid("indices", indices['dtype'], (mstype.int32,), self.name) + validator.check_tensor_dtype_valid("indices", indices['dtype'], (mstype.int32, mstype.int64), self.name) validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name) args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']} validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) diff --git a/tests/st/ops/gpu/test_one_hot_op.py b/tests/st/ops/gpu/test_one_hot_op.py index 59acb35875..86a19a07be 100644 --- a/tests/st/ops/gpu/test_one_hot_op.py +++ b/tests/st/ops/gpu/test_one_hot_op.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,16 +44,13 @@ class NetOneHot(nn.Cell): self.one_hot_3(indices3), self.one_hot_4(indices4)) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_one_hot(): - one_hot = NetOneHot() - indices1 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32)) - indices2 = Tensor(np.array([1, 2, 3]).astype(np.int32)) - indices3 = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32)) - indices4 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32)) - output = one_hot(indices1, indices2, indices3, indices4) +def one_hot(nptype): + one_hot_net = NetOneHot() + indices1 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(nptype)) + indices2 = Tensor(np.array([1, 2, 3]).astype(nptype)) + indices3 = Tensor(np.array([[0, 1], [1, 0]]).astype(nptype)) + indices4 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(nptype)) + output = one_hot_net(indices1, indices2, indices3, indices4) expect_0 = np.array([ [[2., 3., 3., 3., 3., 3.], [3., 2., 3., 3., 3., 3.]], [[3., 3., 3., 3., 2., 3.], [3., 3., 3., 3., 3., 2.]], @@ -80,3 +77,15 @@ def test_one_hot(): assert (output[1].asnumpy() == expect_1).all() assert (output[2].asnumpy() == expect_2).all() assert (output[3].asnumpy() == expect_3).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_one_hot_int32(): + one_hot(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_one_hot_int64(): + one_hot(np.int64)