diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc index f5979dc62d..1a88d0863f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc @@ -27,5 +27,14 @@ MS_REG_GPU_KERNEL_ONE(Concat, MS_REG_GPU_KERNEL_ONE( Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ConcatV2GpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(Concat, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + ConcatV2GpuFwdKernel, short) // NOLINT +MS_REG_GPU_KERNEL_ONE(Concat, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + ConcatV2GpuFwdKernel, char) +MS_REG_GPU_KERNEL_ONE(Concat, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + ConcatV2GpuFwdKernel, bool) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc index 38f168a9b7..141e28daf6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc @@ -1,33 +1,42 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - GatherNd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - GatherNdGpuFwdKernel, float, int) -MS_REG_GPU_KERNEL_TWO( - GatherNd, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - GatherNdGpuFwdKernel, half, int) -MS_REG_GPU_KERNEL_TWO( - GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - GatherNdGpuFwdKernel, int, int) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + GatherNd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherNdGpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + GatherNd, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + GatherNdGpuFwdKernel, half, int) +MS_REG_GPU_KERNEL_TWO( + GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + GatherNdGpuFwdKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), + GatherNdGpuFwdKernel, short, int) // NOLINT +MS_REG_GPU_KERNEL_TWO( + GatherNd, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), + GatherNdGpuFwdKernel, char, int) +MS_REG_GPU_KERNEL_TWO( + GatherNd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + GatherNdGpuFwdKernel, bool, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc index 5ecb9d2a55..7be294b591 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc @@ -24,5 +24,11 @@ MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16 StridedSliceGpuKernel, half) MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), StridedSliceGpuKernel, int) +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + StridedSliceGpuKernel, short) // NOLINT +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + StridedSliceGpuKernel, char) +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + StridedSliceGpuKernel, bool) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc index bbcce07a09..cf28ce0179 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc @@ -24,5 +24,11 @@ MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFlo StridedSliceGradGpuKernel, half) MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), StridedSliceGradGpuKernel, int) +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + StridedSliceGradGpuKernel, short) // NOLINT +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + StridedSliceGradGpuKernel, char) +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + StridedSliceGradGpuKernel, bool) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu index b45d2749a3..d8a660139d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu @@ -36,6 +36,26 @@ __global__ void CheckValidKernel(const size_t size, const T *box, const T *img_m return; } +template +__global__ void CheckValidKernel(const size_t size, const char *box, const char *img_metas, S *valid) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + const size_t left_x = i * 4; + const size_t left_y = i * 4 + 1; + const size_t right_x = i * 4 + 2; + const size_t right_y = i * 4 + 3; + + S valid_flag = false; + valid_flag |= !((unsigned int)box[left_x] >= 0); + valid_flag |= !((unsigned int)box[left_y] >= 0); + valid_flag |= !((unsigned int)img_metas[0] * (unsigned int)img_metas[2] - 1 >= (unsigned int)box[right_x]); + valid_flag |= !((unsigned int)img_metas[1] * (unsigned int)img_metas[2] - 1 >= (unsigned int)box[right_y]); + + valid[i] = !valid_flag; + } + + return; +} + template void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid, cudaStream_t cuda_stream) { CheckValidKernel<<>>(size, box, img_metas, valid); @@ -45,3 +65,7 @@ template void CheckValid(const size_t &size, const float *box, const float *img_ cudaStream_t cuda_stream); template void CheckValid(const size_t &size, const half *box, const half *img_metas, bool *valid, cudaStream_t cuda_stream); +template void CheckValid(const size_t &size, const short *box, const short *img_metas, bool *valid, // NOLINT + cudaStream_t cuda_stream); +template void CheckValid(const size_t &size, const char *box, const char *img_metas, bool *valid, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu index 4866d61dd9..fe726f9550 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu @@ -67,3 +67,15 @@ template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis, int* len_axis, half** inputs, half* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, short** inputs, short* output, // NOLINT + cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, char** inputs, char* output, + cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, bool** inputs, bool* output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu index 3d02723218..adfa0adacb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu @@ -63,3 +63,12 @@ template void GatherNd(half *input, int *indices, half *output, const template void GatherNd(int *input, int *indices, int *output, const size_t &output_dim0, const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, int *batch_strides, cudaStream_t stream); +template void GatherNd(short *input, int *indices, short *output, const size_t &output_dim0, // NOLINT + const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, + int *batch_strides, cudaStream_t stream); +template void GatherNd(char *input, int *indices, char *output, const size_t &output_dim0, + const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, + int *batch_strides, cudaStream_t stream); +template void GatherNd(bool *input, int *indices, bool *output, const size_t &output_dim0, + const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, + int *batch_strides, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu index 6e73e29b5a..140298d44c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu @@ -163,6 +163,7 @@ template void Slice4DKernel(const int s1, const int s2, const int s3, const int template void CalSliceGrad(const size_t input_size, const float *dy, const std::vector in_shape, const std::vector begin, const std::vector size, float *output, cudaStream_t cuda_stream); + template void FillDeviceArray(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream); template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, @@ -170,6 +171,7 @@ template void Slice4DKernel(const int s1, const int s2, const int s3, const int template void CalSliceGrad(const size_t input_size, const half *dy, const std::vector in_shape, const std::vector begin, const std::vector size, half *output, cudaStream_t cuda_stream); + template void FillDeviceArray(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream); template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, @@ -178,6 +180,31 @@ template void CalSliceGrad(const size_t input_size, const int *dy, const st const std::vector begin, const std::vector size, int *output, cudaStream_t cuda_stream); +// NOLINTNEXTLINE +template void FillDeviceArray(const size_t input_size, short *addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, + const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, + const short *input, short *output, cudaStream_t stream); // NOLINT +template void CalSliceGrad(const size_t input_size, const short *dy, const std::vector in_shape, // NOLINT + const std::vector begin, const std::vector size, short *output, // NOLINT + cudaStream_t cuda_stream); + +template void FillDeviceArray(const size_t input_size, char *addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, + const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, + const char *input, char *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const char *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, char *output, + cudaStream_t cuda_stream); + +template void FillDeviceArray(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, + const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, + const bool *input, bool *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const bool *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, bool *output, + cudaStream_t cuda_stream); + template void StridedSlice(const std::vector &input_shape, const std::vector &begin, const std::vector &strides, const std::vector &output_shape, const float *input, float *output, cudaStream_t cuda_stream); @@ -187,6 +214,16 @@ template void StridedSlice(const std::vector &input_shape, const std::ve template void StridedSlice(const std::vector &input_shape, const std::vector &begin, const std::vector &strides, const std::vector &output_shape, const int *input, int *output, cudaStream_t cuda_stream); +template void StridedSlice(const std::vector &input_shape, const std::vector &begin, + // NOLINTNEXTLINE + const std::vector &strides, const std::vector &output_shape, const short *input, + short *output, cudaStream_t cuda_stream); // NOLINT +template void StridedSlice(const std::vector &input_shape, const std::vector &begin, + const std::vector &strides, const std::vector &output_shape, const char *input, + char *output, cudaStream_t cuda_stream); +template void StridedSlice(const std::vector &input_shape, const std::vector &begin, + const std::vector &strides, const std::vector &output_shape, const bool *input, + bool *output, cudaStream_t cuda_stream); template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, const std::vector &strides, const std::vector &dx_shape, const float *dy, @@ -197,3 +234,13 @@ template void StridedSliceGrad(const std::vector &dy_shape, const std::vect template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, const std::vector &strides, const std::vector &dx_shape, const int *dy, int *dx, cudaStream_t cuda_stream); +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + // NOLINTNEXTLINE + const std::vector &strides, const std::vector &dx_shape, const short *dy, + short *dx, cudaStream_t cuda_stream); // NOLINT +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const char *dy, + char *dx, cudaStream_t cuda_stream); +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const bool *dy, + bool *dx, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc index 35deb0cd8f..a9631307dd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc @@ -1,30 +1,36 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - CheckValid, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - CheckValidGpuKernel, float, bool) -MS_REG_GPU_KERNEL_TWO( - CheckValid, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - CheckValidGpuKernel, half, bool) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + CheckValid, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + CheckValidGpuKernel, float, bool) +MS_REG_GPU_KERNEL_TWO( + CheckValid, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + CheckValidGpuKernel, half, bool) +MS_REG_GPU_KERNEL_TWO( + CheckValid, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), + CheckValidGpuKernel, short, bool) // NOLINT +MS_REG_GPU_KERNEL_TWO( + CheckValid, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), + CheckValidGpuKernel, char, bool) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 105e0807cc..417441fb41 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -1304,7 +1304,7 @@ class StridedSliceGrad(PrimitiveWithInfer): def __infer__(self, dy, shapex, begin, end, strides): args = {"dy": dy['dtype']} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) for idx, item in enumerate(shapex['value']): validator.check_value_type("shapex[%d]" % idx, item, [int], self.name) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2ec3f12872..569a9c2adf 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -879,7 +879,7 @@ class Fill(PrimitiveWithInfer): validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name) for idx, item in enumerate(dims['value']): validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name) - valid_types = [mstype.bool_, mstype.int8, mstype.int32, mstype.int64, + valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint32, mstype.uint64, mstype.float16, mstype.float32, mstype.float64] validator.check_type_same({"value": dtype['value']}, valid_types, self.name) diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index b7e7991caa..ac0817608b 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -221,7 +221,7 @@ class CheckValid(PrimitiveWithInfer): return bboxes_shape[:-1] def infer_dtype(self, bboxes_type, metas_type): - valid_type = [mstype.float32, mstype.float16] + valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8] validator.check_tensor_type_same({"bboxes_type": bboxes_type}, valid_type, self.name) validator.check_tensor_type_same({"metas_type": metas_type}, valid_type, self.name) return mstype.bool_ diff --git a/tests/st/ops/gpu/test_check_valid_op.py b/tests/st/ops/gpu/test_check_valid_op.py index 2f30ecfc6e..937c5b838d 100644 --- a/tests/st/ops/gpu/test_check_valid_op.py +++ b/tests/st/ops/gpu/test_check_valid_op.py @@ -16,7 +16,6 @@ import numpy as np import pytest -import mindspore import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor @@ -31,24 +30,57 @@ class NetCheckValid(nn.Cell): def construct(self, anchor, image_metas): return self.valid(anchor, image_metas) +def check_valid(nptype): + anchor = np.array([[50, 0, 100, 700], [-2, 2, 8, 100], [10, 20, 300, 2000]], nptype) + image_metas = np.array([768, 1280, 1], nptype) + anchor_box = Tensor(anchor) + image_metas_box = Tensor(image_metas) + expect = np.array([True, False, False], np.bool) + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + boundingbox_decode = NetCheckValid() + output = boundingbox_decode(anchor_box, image_metas_box) + assert np.array_equal(output.asnumpy(), expect) + + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + boundingbox_decode = NetCheckValid() + output = boundingbox_decode(anchor_box, image_metas_box) + assert np.array_equal(output.asnumpy(), expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_check_valid_float32(): + check_valid(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_check_valid_float16(): + check_valid(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_check_valid_int16(): + check_valid(np.int16) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_boundingbox_decode(): - anchor = np.array([[50, 0, 100, 700], [-2, 2, 8, 100], [10, 20, 300, 2000]], np.float32) - image_metas = np.array([768, 1280, 1], np.float32) - anchor_box = Tensor(anchor, mindspore.float32) - image_metas_box = Tensor(image_metas, mindspore.float32) - expect = np.array([True, False, False], np.bool_) +def test_check_valid_uint8(): + anchor = np.array([[5, 0, 10, 70], [2, 2, 8, 10], [1, 2, 30, 200]], np.uint8) + image_metas = np.array([76, 128, 1], np.uint8) + anchor_box = Tensor(anchor) + image_metas_box = Tensor(image_metas) + expect = np.array([True, True, False], np.bool) context.set_context(mode=context.GRAPH_MODE, device_target='GPU') boundingbox_decode = NetCheckValid() output = boundingbox_decode(anchor_box, image_metas_box) - diff = (output.asnumpy() == expect) - assert (diff == 1).all() + assert np.array_equal(output.asnumpy(), expect) context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') boundingbox_decode = NetCheckValid() output = boundingbox_decode(anchor_box, image_metas_box) - diff = (output.asnumpy() == expect) - assert (diff == 1).all() + assert np.array_equal(output.asnumpy(), expect) diff --git a/tests/st/ops/gpu/test_concatv2_op.py b/tests/st/ops/gpu/test_concatv2_op.py index dc71e9622b..0b35652eec 100644 --- a/tests/st/ops/gpu/test_concatv2_op.py +++ b/tests/st/ops/gpu/test_concatv2_op.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,97 +24,163 @@ from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.ops import operations as P -context.set_context(device_target='GPU') - class ConcatV32(nn.Cell): - def __init__(self): + def __init__(self, nptype): super(ConcatV32, self).__init__() self.cat = P.Concat(axis=2) self.x1 = Parameter(initializer( - Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(np.float32)), [2, 2, 1]), name='x1') + Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(nptype)), [2, 2, 1]), name='x1') self.x2 = Parameter(initializer( - Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(np.float32)), [2, 2, 2]), name='x2') + Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(nptype)), [2, 2, 2]), name='x2') @ms_function def construct(self): return self.cat((self.x1, self.x2)) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_axis32(): - cat = ConcatV32() +def axis32(nptype): + context.set_context(device_target='GPU') + + cat = ConcatV32(nptype) output = cat() - expect = [[[0., 0., 1.], - [1., 2., 3.]], - [[2., 4., 5.], - [3., 6., 7.]]] + expect = np.array([[[0., 0., 1.], + [1., 2., 3.]], + [[2., 4., 5.], + [3., 6., 7.]]]).astype(nptype) print(output) assert (output.asnumpy() == expect).all() +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis32_float32(): + axis32(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis32_int16(): + axis32(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis32_uint8(): + axis32(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis32_bool(): + axis32(np.bool) + class ConcatV43(nn.Cell): - def __init__(self): + def __init__(self, nptype): super(ConcatV43, self).__init__() self.cat = P.Concat(axis=3) self.x1 = Parameter(initializer( - Tensor(np.arange(2 * 2 * 2 * 2).reshape(2, 2, 2, 2).astype(np.float32)), [2, 2, 2, 2]), name='x1') + Tensor(np.arange(2 * 2 * 2 * 2).reshape(2, 2, 2, 2).astype(nptype)), [2, 2, 2, 2]), name='x1') self.x2 = Parameter(initializer( - Tensor(np.arange(2 * 2 * 2 * 3).reshape(2, 2, 2, 3).astype(np.float32)), [2, 2, 2, 3]), name='x2') + Tensor(np.arange(2 * 2 * 2 * 3).reshape(2, 2, 2, 3).astype(nptype)), [2, 2, 2, 3]), name='x2') @ms_function def construct(self): return self.cat((self.x1, self.x2)) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_axis43(): - cat = ConcatV43() +def axis43(nptype): + context.set_context(device_target='GPU') + + cat = ConcatV43(nptype) output = cat() - expect = [[[[0., 1., 0., 1., 2.], - [2., 3., 3., 4., 5.]], - [[4., 5., 6., 7., 8.], - [6., 7., 9., 10., 11.]]], - [[[8., 9., 12., 13., 14.], - [10., 11., 15., 16., 17.]], - [[12., 13., 18., 19., 20.], - [14., 15., 21., 22., 23.]]]] + expect = np.array([[[[0., 1., 0., 1., 2.], + [2., 3., 3., 4., 5.]], + [[4., 5., 6., 7., 8.], + [6., 7., 9., 10., 11.]]], + [[[8., 9., 12., 13., 14.], + [10., 11., 15., 16., 17.]], + [[12., 13., 18., 19., 20.], + [14., 15., 21., 22., 23.]]]]).astype(nptype) assert (output.asnumpy() == expect).all() print(output) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis43_float32(): + axis43(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis43_int16(): + axis43(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis43_uint8(): + axis43(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis43_bool(): + axis43(np.bool) + + class ConcatV21(nn.Cell): - def __init__(self): + def __init__(self, nptype): super(ConcatV21, self).__init__() self.cat = P.Concat(axis=1) self.x1 = Parameter(initializer( - Tensor(np.arange(2 * 2).reshape(2, 2).astype(np.float32)), [2, 2]), name='x1') + Tensor(np.arange(2 * 2).reshape(2, 2).astype(nptype)), [2, 2]), name='x1') self.x2 = Parameter(initializer( - Tensor(np.arange(2 * 3).reshape(2, 3).astype(np.float32)), [2, 3]), name='x2') + Tensor(np.arange(2 * 3).reshape(2, 3).astype(nptype)), [2, 3]), name='x2') @ms_function def construct(self): return self.cat((self.x1, self.x2)) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_axis21(): - cat = ConcatV21() +def axis21(nptype): + cat = ConcatV21(nptype) output = cat() - expect = [[0., 1., 0., 1., 2.], - [2., 3., 3., 4., 5.]] + expect = np.array([[0., 1., 0., 1., 2.], + [2., 3., 3., 4., 5.]]).astype(nptype) assert (output.asnumpy() == expect).all() print(output) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis21_float32(): + axis21(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis21_int16(): + axis21(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis21_uint8(): + axis21(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_axis21_bool(): + axis21(np.bool) + class Concat3INet(nn.Cell): def __init__(self): @@ -125,15 +191,12 @@ class Concat3INet(nn.Cell): return self.cat((x1, x2, x3)) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_concat_3i(): +def concat_3i(nptype): cat = Concat3INet() - x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32) - x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32) - x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32) + x1_np = np.random.randn(32, 4, 224, 224).astype(nptype) + x2_np = np.random.randn(32, 8, 224, 224).astype(nptype) + x3_np = np.random.randn(32, 10, 224, 224).astype(nptype) output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1) x1_ms = Tensor(x1_np) @@ -145,6 +208,42 @@ def test_concat_3i(): diff = output_ms.asnumpy() - output_np assert np.all(diff < error) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_3i_float32(): + concat_3i(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_3i_int16(): + concat_3i(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_3i_uint8(): + concat_3i(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_3i_bool(): + cat = Concat3INet() + + x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool) + x2_np = np.random.choice([True, False], (32, 8, 224, 224)).astype(np.bool) + x3_np = np.random.choice([True, False], (32, 10, 224, 224)).astype(np.bool) + output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1) + + x1_ms = Tensor(x1_np) + x2_ms = Tensor(x2_np) + x3_ms = Tensor(x3_np) + output_ms = cat(x1_ms, x2_ms, x3_ms) + + assert (output_ms.asnumpy() == output_np).all() + class Concat4INet(nn.Cell): def __init__(self): @@ -155,16 +254,13 @@ class Concat4INet(nn.Cell): return self.cat((x1, x2, x3, x4)) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_concat_4i(): +def concat_4i(nptype): cat = Concat4INet() - x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32) - x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32) - x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32) - x4_np = np.random.randn(32, 5, 224, 224).astype(np.float32) + x1_np = np.random.randn(32, 4, 224, 224).astype(nptype) + x2_np = np.random.randn(32, 8, 224, 224).astype(nptype) + x3_np = np.random.randn(32, 10, 224, 224).astype(nptype) + x4_np = np.random.randn(32, 5, 224, 224).astype(nptype) output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1) x1_ms = Tensor(x1_np) @@ -176,3 +272,41 @@ def test_concat_4i(): error = np.ones(shape=output_np.shape) * 10e-6 diff = output_ms.asnumpy() - output_np assert np.all(diff < error) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_4i_float32(): + concat_4i(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_4i_int16(): + concat_4i(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_4i_uint8(): + concat_4i(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_4i_bool(): + cat = Concat4INet() + + x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool) + x2_np = np.random.choice([True, False], (32, 8, 224, 224)).astype(np.bool) + x3_np = np.random.choice([True, False], (32, 10, 224, 224)).astype(np.bool) + x4_np = np.random.choice([True, False], (32, 5, 224, 224)).astype(np.bool) + output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1) + + x1_ms = Tensor(x1_np) + x2_ms = Tensor(x2_np) + x3_ms = Tensor(x3_np) + x4_ms = Tensor(x4_np) + output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms) + + assert (output_ms.asnumpy() == output_np).all() diff --git a/tests/st/ops/gpu/test_gathernd_op.py b/tests/st/ops/gpu/test_gathernd_op.py index c901eb08f2..ddacb20f24 100644 --- a/tests/st/ops/gpu/test_gathernd_op.py +++ b/tests/st/ops/gpu/test_gathernd_op.py @@ -28,28 +28,50 @@ class GatherNdNet(nn.Cell): def construct(self, x, indices): return self.gathernd(x, indices) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_gathernd0(): - x = Tensor(np.arange(3 * 2, dtype=np.float32).reshape(3, 2)) + +def gathernd0(nptype): + x = Tensor(np.arange(3 * 2, dtype=nptype).reshape(3, 2)) indices = Tensor(np.array([[1, 1], [0, 1]]).astype(np.int32)) - expect = np.array([3., 1.]) + expect = np.array([3, 1]).astype(nptype) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") gathernd = GatherNdNet() output = gathernd(x, indices) - error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 - diff = output.asnumpy() - expect - assert np.all(diff < error) - assert np.all(-diff < error) + assert np.array_equal(output.asnumpy(), expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd0_float32(): + gathernd0(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd0_float16(): + gathernd0(np.float16) @pytest.mark.level0 -@pytest.mark.platform_x86_gpu_traning +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_gathernd1(): - x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) +def test_gathernd0_int32(): + gathernd0(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd0_int16(): + gathernd0(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd0_uint8(): + gathernd0(np.uint8) + +def gathernd1(nptype): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=nptype).reshape(2, 3, 4, 5)) indices = Tensor(np.array([[[[[l, k, j, i] for i in [1, 3, 4]] for j in range(4)] for k in range(3)] for l in range(2)], dtype='i4')) expect = np.array([[[[1., 3., 4.], @@ -80,21 +102,45 @@ def test_gathernd1(): [[101., 103., 104.], [106., 108., 109.], [111., 113., 114.], - [116., 118., 119.]]]]) + [116., 118., 119.]]]]).astype(nptype) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") gather = GatherNdNet() output = gather(x, indices) - error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 - diff = output.asnumpy() - expect - assert np.all(diff < error) - assert np.all(-diff < error) + assert np.array_equal(output.asnumpy(), expect) @pytest.mark.level0 -@pytest.mark.platform_x86_gpu_traning +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd1_float32(): + gathernd1(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd1_float16(): + gathernd1(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_gathernd2(): +def test_gathernd1_int32(): + gathernd1(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd1_int16(): + gathernd1(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd1_uint8(): + gathernd1(np.uint8) + +def gathernd2(nptype): x = Tensor(np.array([[4., 5., 4., 1., 5.], [4., 9., 5., 6., 4.], [9., 8., 4., 3., 6.], @@ -115,37 +161,48 @@ def test_gathernd2(): gathernd = GatherNdNet() output = gathernd(x, indices) - error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 - diff = output.asnumpy() - expect - assert np.all(diff < error) - assert np.all(-diff < error) + assert np.array_equal(output.asnumpy(), expect) @pytest.mark.level0 -@pytest.mark.platform_x86_gpu_traning +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_gathernd3(): - x = Tensor(np.array([[4, 5, 4, 1, 5], - [4, 9, 5, 6, 4], - [9, 8, 4, 3, 6], - [0, 4, 2, 2, 8], - [1, 8, 6, 2, 8], - [8, 1, 9, 7, 3], - [7, 9, 2, 5, 7], - [9, 8, 6, 8, 5], - [3, 7, 2, 7, 4], - [4, 2, 8, 2, 9]] - ).astype(np.int32)) +def test_gathernd2_float32(): + gathernd2(np.float32) - indices = Tensor(np.array([[4000], [1], [300000]]).astype(np.int32)) - expect = np.array([[0, 0, 0, 0, 0], - [4, 9, 5, 6, 4], - [0, 0, 0, 0, 0]]) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd2_float16(): + gathernd2(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd2_int32(): + gathernd2(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd2_int16(): + gathernd2(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd2_uint8(): + gathernd2(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd_bool(): + x = Tensor(np.array([[True, False], [False, False]]).astype(np.bool)) + indices = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]).astype(np.int32)) + expect = np.array([True, False, False, False]).astype(np.bool) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") gathernd = GatherNdNet() output = gathernd(x, indices) - error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 - diff = output.asnumpy() - expect - assert np.all(diff < error) - assert np.all(-diff < error) + assert np.array_equal(output.asnumpy(), expect) diff --git a/tests/st/ops/gpu/test_stridedslice_grad_op.py b/tests/st/ops/gpu/test_stridedslice_grad_op.py index 39d31c53cb..2faa32c706 100644 --- a/tests/st/ops/gpu/test_stridedslice_grad_op.py +++ b/tests/st/ops/gpu/test_stridedslice_grad_op.py @@ -22,9 +22,6 @@ from mindspore import Tensor from mindspore.ops import operations as P from mindspore.ops import composite as C -context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - - class StridedSliceNet(nn.Cell): def __init__(self, begin, end, stride, begin_mask=0, end_mask=0, ellipsis_mask=0): super(StridedSliceNet, self).__init__() @@ -45,11 +42,11 @@ class GradData(nn.Cell): def construct(self, x): return self.grad(self.network)(x) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_strided_slice_grad(): - x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(np.float32)) + +def strided_slice_grad(nptype): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(nptype)) net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1)) dx = GradData(net)(x) expect = np.array([[[[0., 0., 0., 0., 0.], @@ -81,7 +78,7 @@ def test_strided_slice_grad(): [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0.]]]]) + [0., 0., 0., 0., 0.]]]]).astype(nptype) assert np.allclose(dx[0].asnumpy(), expect) net = StridedSliceNet((1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2)) @@ -115,7 +112,7 @@ def test_strided_slice_grad(): [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0.]]]]) + [0., 0., 0., 0., 0.]]]]).astype(nptype) assert np.allclose(dx[0].asnumpy(), expect) @@ -150,7 +147,7 @@ def test_strided_slice_grad(): [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0.]]]]) + [0., 0., 0., 0., 0.]]]]).astype(nptype) assert np.allclose(dx[0].asnumpy(), expect) # ME infer fault @@ -253,7 +250,7 @@ def test_strided_slice_grad(): [[1., 1., 1., 1., 0.], [1., 1., 1., 1., 0.], [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.]]]]) + [1., 1., 1., 1., 0.]]]]).astype(nptype) assert np.allclose(dx[0].asnumpy(), expect) x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32)) @@ -272,10 +269,10 @@ def test_strided_slice_grad(): [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0.]]]) + [0., 0., 0., 0., 0.]]]).astype(nptype) assert np.allclose(dx[0].asnumpy(), expect) - x = Tensor(np.arange(0, 1 * 1 * 1 * 2 * 3 * 4 * 5).reshape(1, 1, 1, 2, 3, 4, 5).astype(np.float32)) + x = Tensor(np.arange(0, 1 * 1 * 1 * 2 * 3 * 4 * 5).reshape(1, 1, 1, 2, 3, 4, 5).astype(nptype)) net = StridedSliceNet((0, 0, 0, 1, 1, 2, 2), (1, 1, 1, 2, 3, 3, 4), (1, 1, 1, 1, 1, 1, 1)) dx = GradData(net)(x) expect = np.array([[[[[[[0., 0., 0., 0., 0.], @@ -306,5 +303,29 @@ def test_strided_slice_grad(): [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 1., 1., 0.], - [0., 0., 0., 0., 0.]]]]]]]) + [0., 0., 0., 0., 0.]]]]]]]).astype(nptype) assert np.allclose(dx[0].asnumpy(), expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_grad_float32(): + strided_slice_grad(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_grad_int16(): + strided_slice_grad(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_grad_uint8(): + strided_slice_grad(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_grad_bool(): + strided_slice_grad(np.bool) diff --git a/tests/st/ops/gpu/test_stridedslice_op.py b/tests/st/ops/gpu/test_stridedslice_op.py index 098d18c9cb..61f7e479c1 100644 --- a/tests/st/ops/gpu/test_stridedslice_op.py +++ b/tests/st/ops/gpu/test_stridedslice_op.py @@ -20,33 +20,29 @@ import mindspore.context as context from mindspore import Tensor from mindspore.ops import operations as P -context.set_context(mode=context.GRAPH_MODE, device_target='GPU') +def strided_slice(nptype): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_stridedslice(): - x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(np.float32)) + x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(nptype)) y = P.StridedSlice()(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1)) expect = np.array([[[[62, 63], [67, 68]], [[82, 83], - [87, 88]]]]) + [87, 88]]]]).astype(nptype) assert np.allclose(y.asnumpy(), expect) y = P.StridedSlice()(x, (1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2)) expect = np.array([[[[64, 62], [69, 67]], [[84, 82], - [89, 87]]]]) + [89, 87]]]]).astype(nptype) assert np.allclose(y.asnumpy(), expect) y = P.StridedSlice()(x, (1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1)) expect = np.array([[[[64, 63, 62], [69, 68, 67]], [[84, 83, 82], - [89, 88, 87]]]]) + [89, 88, 87]]]]).astype(nptype) assert np.allclose(y.asnumpy(), expect) # ME infer fault @@ -81,20 +77,20 @@ def test_stridedslice(): [[100, 101, 102, 103], [105, 106, 107, 108], [110, 111, 112, 113], - [115, 116, 117, 118]]]]) + [115, 116, 117, 118]]]]).astype(nptype) assert np.allclose(y.asnumpy(), expect) - x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32)) + x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(nptype)) y = P.StridedSlice()(x, (1, 0, 0), (2, -3, 3), (1, 1, 3)) - expect = np.array([[[20]]]) + expect = np.array([[[20]]]).astype(nptype) assert np.allclose(y.asnumpy(), expect) - x_np = np.arange(0, 4*5).reshape(4, 5).astype(np.float32) + x_np = np.arange(0, 4*5).reshape(4, 5).astype(nptype) y = Tensor(x_np)[:, ::-1] expect = x_np[:, ::-1] assert np.allclose(y.asnumpy(), expect) - x = Tensor(np.arange(0, 2 * 3 * 4 * 5 * 4 * 3 * 2).reshape(2, 3, 4, 5, 4, 3, 2).astype(np.float32)) + x = Tensor(np.arange(0, 2 * 3 * 4 * 5 * 4 * 3 * 2).reshape(2, 3, 4, 5, 4, 3, 2).astype(nptype)) y = P.StridedSlice()(x, (1, 0, 0, 2, 1, 2, 0), (2, 2, 2, 4, 2, 3, 2), (1, 1, 1, 1, 1, 1, 2)) expect = np.array([[[[[[[1498.]]], [[[1522.]]]], @@ -103,5 +99,29 @@ def test_stridedslice(): [[[[[1978.]]], [[[2002.]]]], [[[[2098.]]], - [[[2122.]]]]]]]) + [[[2122.]]]]]]]).astype(nptype) assert np.allclose(y.asnumpy(), expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_float32(): + strided_slice(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_int16(): + strided_slice(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_uint8(): + strided_slice(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_bool(): + strided_slice(np.bool)