diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc
index eb6342e9d2..2b6848fd36 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,10 +22,20 @@ MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32
                       StridedSliceGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       StridedSliceGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      StridedSliceGpuKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       StridedSliceGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                       StridedSliceGpuKernel, short)  // NOLINT
+MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                      StridedSliceGpuKernel, int8_t)
+MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
+                      StridedSliceGpuKernel, uint64_t)
+MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+                      StridedSliceGpuKernel, uint32_t)
+MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
+                      StridedSliceGpuKernel, uint16_t)
 MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                       StridedSliceGpuKernel, uchar)
 MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h
index adaa38b542..3aa3f4df06 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H
 
 #include
 #include
@@ -210,4 +210,4 @@ class StridedSliceGpuKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc
index 58e2c11098..ceaccaa26f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,10 +22,20 @@ MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFlo
                       StridedSliceGradGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       StridedSliceGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      StridedSliceGradGpuKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       StridedSliceGradGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                       StridedSliceGradGpuKernel, short)  // NOLINT
+MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                      StridedSliceGradGpuKernel, int8_t)
+MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
+                      StridedSliceGradGpuKernel, uint64_t)
+MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+                      StridedSliceGradGpuKernel, uint32_t)
+MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
+                      StridedSliceGradGpuKernel, uint16_t)
 MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                       StridedSliceGradGpuKernel, uchar)
 MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h
index d6c0680963..5c1e6ef425 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H
 
 #include
 #include
@@ -211,4 +211,4 @@ class StridedSliceGradGpuKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
index ae63fc0649..9b733ebf86 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
@@ -159,7 +159,6 @@ void StridedSliceGrad(const std::vector &dy_shape, const std::vector
-template void FillDeviceArray(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const float *input, float *output, cudaStream_t stream);
@@ -167,7 +166,6 @@ template void CalSliceGrad(const size_t input_size, const float *dy, cons
                                   const std::vector begin, const std::vector size, float *output,
                                   cudaStream_t cuda_stream);
-template void FillDeviceArray(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream);
@@ -175,7 +173,6 @@ template void CalSliceGrad(const size_t input_size, const half *dy, const
                                  const std::vector begin, const std::vector size, half *output,
                                  cudaStream_t cuda_stream);
-template void FillDeviceArray(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream);
@@ -183,8 +180,6 @@ template void CalSliceGrad(const size_t input_size, const int *dy, const st
                                 const std::vector begin, const std::vector size, int *output,
                                 cudaStream_t cuda_stream);
-template void FillDeviceArray(const size_t input_size, short *addr, const float value,  // NOLINT
-                              cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const short *input, short *output,  // NOLINT
@@ -195,8 +190,6 @@ template void CalSliceGrad(const size_t input_size, const short *dy,  //
                                   short *output,  // NOLINT
                                   cudaStream_t cuda_stream);
-template void FillDeviceArray(const size_t input_size, unsigned char *addr, const float value,
-                              cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output,
@@ -206,8 +199,6 @@ template void CalSliceGrad(const size_t input_size, const unsigne
                                          const std::vector size, unsigned char *output,
                                          cudaStream_t cuda_stream);
-template void FillDeviceArray(const size_t input_size, int64_t *addr, const float value,
-                              cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
@@ -216,7 +207,6 @@ template void CalSliceGrad(const size_t input_size, const int64_t *dy,
                                    const std::vector begin, const std::vector size, int64_t *output,
                                    cudaStream_t cuda_stream);
-template void FillDeviceArray(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);
@@ -224,12 +214,37 @@ template void CalSliceGrad(const size_t input_size, const bool *dy, const
                                 const std::vector begin, const std::vector size, bool *output,
                                 cudaStream_t cuda_stream);
 
+template void FillDeviceArray(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, int64_t *addr, const float value,
+                              cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, short *addr, const float value,  // NOLINT
+                              cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, int8_t *addr, const float value,
+                              cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, uint64_t *addr, const float value,
+                              cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, uint32_t *addr, const float value,
+                              cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, uint16_t *addr, const float value,
+                              cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, unsigned char *addr, const float value,
+                              cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
+
+template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
+                           const std::vector &strides, const std::vector &output_shape,
+                           const bool *input, bool *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
                            const std::vector &strides, const std::vector &output_shape,
                            const float *input, float *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
                            const std::vector &strides, const std::vector &output_shape,
                            const half *input, half *output, cudaStream_t cuda_stream);
+template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
+                           const std::vector &strides, const std::vector &output_shape,
+                           const int64_t *input, int64_t *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
                            const std::vector &strides, const std::vector &output_shape,
                            const int *input, int *output, cudaStream_t cuda_stream);
@@ -238,20 +253,32 @@ template void StridedSlice(const std::vector &input_shape, const std::ve
                            const short *input, short *output, cudaStream_t cuda_stream);  // NOLINT
 template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
                            const std::vector &strides, const std::vector &output_shape,
-                           const unsigned char *input, unsigned char *output, cudaStream_t cuda_stream);
+                           const int8_t *input, int8_t *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
                            const std::vector &strides, const std::vector &output_shape,
-                           const bool *input, bool *output, cudaStream_t cuda_stream);
+                           const uint64_t *input, uint64_t *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
                            const std::vector &strides, const std::vector &output_shape,
-                           const int64_t *input, int64_t *output, cudaStream_t cuda_stream);
+                           const uint32_t *input, uint32_t *output, cudaStream_t cuda_stream);
+template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
+                           const std::vector &strides, const std::vector &output_shape,
+                           const uint16_t *input, uint16_t *output, cudaStream_t cuda_stream);
+template void StridedSlice(const std::vector &input_shape, const std::vector &begin,
+                           const std::vector &strides, const std::vector &output_shape,
+                           const unsigned char *input, unsigned char *output, cudaStream_t cuda_stream);
 
+template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
+                               const std::vector &strides, const std::vector &dx_shape, const bool *dy,
+                               bool *dx, cudaStream_t cuda_stream);
 template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
                                const std::vector &strides, const std::vector &dx_shape, const float *dy,
                                float *dx, cudaStream_t cuda_stream);
 template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
                                const std::vector &strides, const std::vector &dx_shape, const half *dy,
                                half *dx, cudaStream_t cuda_stream);
+template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
+                               const std::vector &strides, const std::vector &dx_shape,
+                               const int64_t *dy, int64_t *dx, cudaStream_t cuda_stream);
 template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
                                const std::vector &strides, const std::vector &dx_shape, const int *dy,
                                int *dx, cudaStream_t cuda_stream);
@@ -261,10 +288,16 @@ template void StridedSliceGrad(const std::vector &dy_shape, const std::v
                                short *dx, cudaStream_t cuda_stream);  // NOLINT
 template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
                                const std::vector &strides, const std::vector &dx_shape,
-                               const unsigned char *dy, unsigned char *dx, cudaStream_t cuda_stream);
+                               const int8_t *dy, int8_t *dx, cudaStream_t cuda_stream);
 template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
-                               const std::vector &strides, const std::vector &dx_shape, const bool *dy,
-                               bool *dx, cudaStream_t cuda_stream);
+                               const std::vector &strides, const std::vector &dx_shape,
+                               const uint64_t *dy, uint64_t *dx, cudaStream_t cuda_stream);
 template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
                                const std::vector &strides, const std::vector &dx_shape,
-                               const int64_t *dy, int64_t *dx, cudaStream_t cuda_stream);
+                               const uint32_t *dy, uint32_t *dx, cudaStream_t cuda_stream);
+template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
+                               const std::vector &strides, const std::vector &dx_shape,
+                               const uint16_t *dy, uint16_t *dx, cudaStream_t cuda_stream);
+template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin,
+                               const std::vector &strides, const std::vector &dx_shape,
+                               const unsigned char *dy, unsigned char *dx, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh
index abc32af340..2bd4ce302d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SLICE_IMPL_CUH_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SLICE_IMPL_CUH_
 
 #include
 #include
@@ -39,4 +39,4 @@ void StridedSliceGrad(const std::vector &dy_shape, const std::vector
 template <typename T>
 void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream);
 
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SLICE_IMPL_CUH_
diff --git a/tests/st/ops/gpu/test_print_op.py b/tests/st/ops/gpu/test_print_op.py
index cd53976c75..2c3f2effc7 100644
--- a/tests/st/ops/gpu/test_print_op.py
+++ b/tests/st/ops/gpu/test_print_op.py
@@ -43,9 +43,21 @@ class PrintNetTwoInputs(nn.Cell):
         return x
 
 
+class PrintNetIndex(nn.Cell):
+    def __init__(self):
+        super(PrintNetIndex, self).__init__()
+        self.op = P.Print()
+
+    def construct(self, x):
+        self.op(x[0][0][6][3])
+        return x
+
+
 def print_testcase(nptype):
     # large shape
     x = np.arange(20808).reshape(6, 3, 34, 34).astype(nptype)
+    # a value that can be stored as int8_t
+    x[0][0][6][3] = 125
     # small shape
     y = np.arange(9).reshape(3, 3).astype(nptype)
     x = Tensor(x)
@@ -54,8 +66,10 @@ def print_testcase(nptype):
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     net_1 = PrintNetOneInput()
     net_2 = PrintNetTwoInputs()
+    net_3 = PrintNetIndex()
     net_1(x)
     net_2(x, y)
+    net_3(x)
 
 
 @pytest.mark.level0
diff --git a/tests/st/ops/gpu/test_stridedslice_grad_op.py b/tests/st/ops/gpu/test_stridedslice_grad_op.py
index 77cb7e6009..093333440b 100644
--- a/tests/st/ops/gpu/test_stridedslice_grad_op.py
+++ b/tests/st/ops/gpu/test_stridedslice_grad_op.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -245,12 +245,54 @@ def strided_slice_grad(nptype):
 def test_strided_slice_grad_float32():
     strided_slice_grad(np.float32)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_grad_float16():
+    strided_slice_grad(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_grad_int64():
+    strided_slice_grad(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_grad_int32():
+    strided_slice_grad(np.int32)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_strided_slice_grad_int16():
     strided_slice_grad(np.int16)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_grad_int8():
+    strided_slice_grad(np.int8)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_grad_uint64():
+    strided_slice_grad(np.uint64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_grad_uint32():
+    strided_slice_grad(np.uint32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_grad_uint16():
+    strided_slice_grad(np.uint16)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
diff --git a/tests/st/ops/gpu/test_stridedslice_op.py b/tests/st/ops/gpu/test_stridedslice_op.py
index 5e88b744fa..ecbb495a89 100644
--- a/tests/st/ops/gpu/test_stridedslice_op.py
+++ b/tests/st/ops/gpu/test_stridedslice_op.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -108,12 +108,54 @@ def strided_slice(nptype):
 def test_strided_slice_float32():
     strided_slice(np.float32)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_float16():
+    strided_slice(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_int64():
+    strided_slice(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_int32():
+    strided_slice(np.int32)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_strided_slice_int16():
     strided_slice(np.int16)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_int8():
+    strided_slice(np.int8)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_uint64():
+    strided_slice(np.uint64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_uint32():
+    strided_slice(np.uint32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_uint16():
+    strided_slice(np.uint16)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard