From 507cc4ab1599e7294f34b75a808e817a1916ffb8 Mon Sep 17 00:00:00 2001
From: Peilin Wang
Date: Tue, 5 Jan 2021 23:49:05 -0500
Subject: [PATCH] range past segfault

addressed comments
changed default max_output_length to 1000000
change docstring
fix ci
change max_output_length to maxlen
---
 .../gpu/arrays/dynamic_range_gpu_kernel.cc  |  54 ++++++++
 .../gpu/arrays/dynamic_range_gpu_kernel.h   | 121 ++++++++++++++++++
 .../gpu/cuda_impl/dynamic_range_impl.cu     |  76 +++++++++++
 .../gpu/cuda_impl/dynamic_range_impl.cuh    |  26 ++++
 mindspore/core/abstract/infer_functions.h   |   4 +-
 mindspore/core/abstract/prim_arrays.cc      |  35 ++++-
 .../core/abstract/primitive_infer_map.cc    |   3 +-
 mindspore/core/base/core_ops.h              |   3 +-
 mindspore/ops/operations/__init__.py        |   5 +-
 mindspore/ops/operations/array_ops.py       |  62 ++++++++-
 tests/st/ops/gpu/test_range_op.py           |  93 ++++++++++++++
 11 files changed, 475 insertions(+), 7 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh
 create mode 100644 tests/st/ops/gpu/test_range_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.cc
new file mode 100644
index 0000000000..487e41b9cf
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.cc
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cstdint>
+
+#include "backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Range,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      DynamicRangeGpuKernel, float)
+
+MS_REG_GPU_KERNEL_ONE(Range,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddOutputAttr(kNumberTypeFloat64),
+                      DynamicRangeGpuKernel, double)
+
+MS_REG_GPU_KERNEL_ONE(Range,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      DynamicRangeGpuKernel, int32_t)
+
+MS_REG_GPU_KERNEL_ONE(Range,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64),
+                      DynamicRangeGpuKernel, int64_t)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h
new file mode 100644
index 0000000000..3ee237f1c8
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h
@@ -0,0 +1,121 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_
+
+#include <cuda_runtime.h>
+
+#include <vector>
+
+#include "backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh"
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class DynamicRangeGpuKernel : public GpuKernel {
+ public:
+  DynamicRangeGpuKernel() { ResetResource(); }
+  ~DynamicRangeGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *range_start = GetDeviceAddress<T>(inputs, 0);
+    T *range_end = GetDeviceAddress<T>(inputs, 1);
+    T *range_delta = GetDeviceAddress<T>(inputs, 2);
+    T *output_device_address = GetDeviceAddress<T>(outputs, 0);
+    int64_t *output_shape_device_address = GetDeviceAddress<int64_t>(workspace, 0);
+
+    stream_ptr_ = stream_ptr;
+
+    CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address,
+             max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    // use workspace[0] for actual output shape, we know it must be 1d
+    CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
+                              cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t),
+                                              cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                              "Failed to copy gpu memory.");
+    CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
+
+    return true;
+  }
+
+  void PostExecute() override {
+    // required synchronize for PostExecute
+    CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
+                               "cudaStreamSynchronize failed");
+
+    std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(c_node_ptr_, 0)};
+    std::vector<std::vector<size_t>> output_shape = {{(size_t)output_shape_}};
+    AnfAlgo::SetOutputInferTypeAndShape(output_type, output_shape, c_node_ptr_.get());
+  }
+
+  void ResetResource() noexcept override {
+    stream_ptr_ = nullptr;
+    c_node_ptr_ = nullptr;
+    output_shape_ = 0;
+    max_output_length_ = 0;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_count != 3) {
+      MS_LOG(ERROR) << input_count << " inputs were provided, but DynamicRangeGpuKernel expects 3.";
+      return false;
+    }
+
+    max_output_length_ = GetAttr<int64_t>(kernel_node, "maxlen");
+    c_node_ptr_ = kernel_node;
+    InitSizeLists();
+
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(sizeof(T));
+    output_size_list_.push_back(max_output_length_ * sizeof(T));
+
+    // this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape.
+    workspace_size_list_.push_back(sizeof(int64_t));
+    return;
+  }
+
+ private:
+  void *stream_ptr_;
+  CNodePtr c_node_ptr_;
+  int64_t output_shape_;
+  int64_t max_output_length_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu
new file mode 100644
index 0000000000..51a9051dd2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dynamic_range_impl.cuh"
+
+#include <cuda_runtime.h>
+
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__device__ void CheckInputs(const T &start, const T &end, const T &delta) {
+  if (delta == 0) {
+    asm("trap;");
+  }
+
+  if (start < end && delta < 0) {
+    asm("trap;");
+  }
+
+  if (start > end && delta > 0) {
+    asm("trap;");
+  }
+}
+
+template <typename T>
+__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output,
+                      int64_t *output_shape, const int64_t max_output_size) {
+  T start = range_start[0];
+  T end = range_end[0];
+  T delta = range_delta[0];
+
+  CheckInputs(start, end, delta);
+
+  int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
+  if (real_output_shape > max_output_size) {
+    asm("trap;");
+  }
+  *output_shape = real_output_shape;
+
+  size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x;
+  for (; gt_id < real_output_shape; gt_id += blockDim.x * gridDim.x) {
+    output[gt_id] = gt_id * delta + start;
+  }
+}
+
+template <typename T>
+void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
+              const int64_t max_output_size, cudaStream_t cuda_stream) {
+  Range<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, cuda_stream>>>(range_start, range_end, range_delta,
+                                                                      output, output_shape, max_output_size);
+}
+
+template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output,
+                            int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);
+
+template void CalRange<int64_t>(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta,
+                                int64_t *output, int64_t *output_shape, const int64_t max_output_size,
+                                cudaStream_t cuda_stream);
+
+template void CalRange<float>(const float *range_start, const float *range_end, const float *range_delta, float *output,
+                              int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);
+template void CalRange<double>(const double *range_start, const double *range_end, const double *range_delta,
+                               double *output, int64_t *output_shape, const int64_t max_output_size,
+                               cudaStream_t cuda_stream);
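Note (reviewer aid, not part of the patch): the kernel above derives the dynamic output length on the device as `ceil((end - start) / delta)` and traps on an invalid triple or when the result exceeds `maxlen`. The sketch below mirrors that check on the host side purely as an illustration of the shapes the op will produce; the helper name `expected_range_length` is hypothetical and not part of MindSpore.

```python
import math

def expected_range_length(start, end, delta, maxlen=1000000):
    # Hypothetical helper mirroring the device-side checks in the Range kernel (illustration only).
    if delta == 0 or (start < end and delta < 0) or (start > end and delta > 0):
        raise ValueError("invalid (start, end, delta) combination")  # the kernel traps here
    length = math.ceil((end - start) / delta)
    if length > maxlen:
        raise ValueError("output would exceed maxlen")  # the kernel also traps here
    return length

assert expected_range_length(2, 5, 1) == 3        # matches test_range_int below
assert expected_range_length(8, 1, -1) == 7
assert expected_range_length(2.3, 5.5, 1.2) == 3  # matches test_range_float below
```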
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh
new file mode 100644
index 0000000000..17b1fd8c0a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
+
+#include <cuda_runtime.h>
+
+template <typename T>
+void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
+              const int64_t max_output_size, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 92dc45a53b..e0d836a084 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -1,7 +1,7 @@
 /**
  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  *
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -285,6 +285,8 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
                                       const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list);
 
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index a243850099..521868e0e5 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -168,6 +168,7 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
   if (max_shape.empty()) {
     max_shape = shape->shape();
   }
+
   auto ids =
     std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
   // Currently we choose the same data type as input for the idx.
@@ -186,6 +187,7 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
   if (idx_max_shape.empty()) {
     idx_max_shape = shape->shape();
   }
+
   auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape);
   ids_idx->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape));
   // outputs: ids, ids_idx
@@ -951,5 +953,36 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
   ShapePtr output_shape = std::make_shared<Shape>(lengths_shape, lengths_shape_min, lengths_shape_max);
   return std::make_shared<AbstractTensor>(kBool, output_shape);
 }
+
+AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list) {
+  const std::string &op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  AbstractTensorPtr range_delta = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
+
+  TypePtrList supported_types = {kInt64, kInt32, kFloat32, kFloat64};
+  TypePtr range_start_type = CheckTensorDType(range_start, supported_types, "range_start input of Range should be %s");
+  TypePtr range_end_type = CheckTensorDType(range_end, supported_types, "range_end input of Range should be %s");
+  TypePtr range_delta_type = CheckTensorDType(range_delta, supported_types, "range_delta input of Range should be %s");
+
+  // check all 3 inputs are same type
+  if (!IsIdentidityOrSubclass(range_start_type, range_end_type) ||
+      !IsIdentidityOrSubclass(range_end_type, range_delta_type)) {
+    MS_LOG(EXCEPTION) << "All inputs must have same type, but got: " << args_spec_list[0]->type_name() << ", "
+                      << args_spec_list[1]->type_name() << ", and " << args_spec_list[2]->type_name();
+  }
+
+  int64_t max_output_length = -1;
+  ValuePtr max_output_length_ptr = primitive->GetAttr("maxlen");
+  max_output_length = GetValue<int64_t>(max_output_length_ptr);
+  ShapeVector output_shape = {Shape::SHP_ANY};
+  ShapeVector min_shape = {1};
+  ShapeVector max_shape = {max_output_length};
+  ShapePtr shape = std::make_shared<Shape>(output_shape, min_shape, max_shape);
+
+  return std::make_shared<AbstractTensor>(range_start_type, shape);
+}
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index 7fcf58fe76..ef125aba0e 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -1,7 +1,7 @@
 /**
  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  *
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -81,6 +81,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimMapUniform, {InferImplMapUniform, true}},
     {prim::kPrimSplit, {InferImplSplit, true}},
     {prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
+    {prim::kPrimRange, {InferImplRange, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 8ab89abffc..c0d693e0bc 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -125,6 +125,7 @@ inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("Scat
 inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform");
 inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");
 inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask");
+inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range");
 
 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 00d37aefe5..9da3fba175 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Transpose,
                         TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
                         UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace,
                         SpaceToBatch, BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo,
                         InplaceUpdate, ReverseSequence, EmbeddingLookup,
-                        Unique, GatherD, Identity)
+                        Unique, GatherD, Identity, Range)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice,
@@ -402,6 +402,7 @@ __all__ = [
     "ReLUV2",
     "SparseToDense",
     "MatrixInverse",
+    "Range",
 ]
 
 __all__.sort()
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 25e202c858..47924b891f 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1,6 +1,6 @@
 # coding: utf-8
 
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -4722,3 +4722,63 @@ class Identity(PrimitiveWithInfer):
                'dtype': x['dtype'],
                'value': None}
         return out
+
+
+class Range(PrimitiveWithCheck):
+    r"""
+    Creates a sequence of numbers that begins at `start` and extends by increments of
+    `delta` up to but not including `limit`.
+
+    The types of all 3 inputs must be the same. The type of the resulting tensor is
+    the same as the type of the inputs.
+
+    Args:
+        maxlen (int): Memory that can fit `maxlen` many elements
+            will be allocated for the output. Optional, must be positive, defaults to 1000000.
+            If the output has more than `maxlen` elements, a runtime error
+            will occur.
+
+    Inputs:
+        - **start** (Tensor) - A scalar Tensor. The first number in the sequence. Must have
+          type: int32 or float32
+        - **limit** (Tensor) - A scalar Tensor. Upper limit of the sequence, exclusive. Must
+          have type: int32 or float32
+        - **delta** (Tensor) - A scalar Tensor. Number that increments `start`. Must have
+          type: int32 or float32
+
+    Outputs:
+        A 1-D Tensor, with the same type as the inputs.
+
+    Examples:
+        >>> start = Tensor(0, mstype.int32)
+        >>> limit = Tensor(10, mstype.int32)
+        >>> delta = Tensor(4, mstype.int32)
+        >>> output = ops.Range()(start, limit, delta)
+        >>> print(output)
+        [0 4 8]
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+    """
+
+    @prim_attr_register
+    def __init__(self, maxlen=1000000):
+        self.init_prim_io_names(inputs=['start', 'limit', 'delta'], outputs=['output'])
+        validator.check_value_type("maxlen", maxlen, [int], self.name)
+        validator.check_positive_int(maxlen, "maxlen", self.name)
+        self.maxlen = maxlen
+        self.add_prim_attr('maxlen', maxlen)
+
+        self.add_prim_attr("dynamic_shape_depends", [0])
+        self.add_prim_attr("dynamic_shape_depends", [1])
+        self.add_prim_attr("dynamic_shape_depends", [2])
+
+    def check_shape(self, start_shape, limit_shape, delta_shape):
+        validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
+        validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
+        validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)
+
+    def check_dtype(self, start_dtype, limit_dtype, delta_dtype):
+        valid_dtypes = [mstype.int32, mstype.float32]
+        inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
+        validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)
diff --git a/tests/st/ops/gpu/test_range_op.py b/tests/st/ops/gpu/test_range_op.py
new file mode 100644
index 0000000000..de76b01525
--- /dev/null
+++ b/tests/st/ops/gpu/test_range_op.py
@@ -0,0 +1,93 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.common.dtype as mstype
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+class RangeNet(nn.Cell):
+    def __init__(self):
+        super(RangeNet, self).__init__()
+        self.range = P.Range()
+
+    def construct(self, s, e, d):
+        return self.range(s, e, d)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_range_int():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
+    np_expected = np.array([2, 3, 4])
+    np.testing.assert_array_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(-24, mstype.int32), Tensor(1, mstype.int32), Tensor(4, mstype.int32)).asnumpy()
+    np_expected = np.array([-24, -20, -16, -12, -8, -4, 0])
+    np.testing.assert_array_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(8, mstype.int32), Tensor(1, mstype.int32), Tensor(-1, mstype.int32)).asnumpy()
+    np_expected = np.array([8, 7, 6, 5, 4, 3, 2])
+    np.testing.assert_array_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(3, mstype.int32), Tensor(-11, mstype.int32), Tensor(-5, mstype.int32)).asnumpy()
+    np_expected = np.array([3, -2, -7])
+    np.testing.assert_array_equal(ms_out, np_expected)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_range_float():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(2.3, mstype.float32), Tensor(5.5, mstype.float32), Tensor(1.2, mstype.float32)).asnumpy()
+    np_expected = np.array([2.3, 3.5, 4.7])
+    np.testing.assert_array_almost_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(-4, mstype.float32), Tensor(-1, mstype.float32), Tensor(1.5, mstype.float32)).asnumpy()
+    np_expected = np.array([-4.0, -2.5])
+    np.testing.assert_array_almost_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(8.0, mstype.float32), Tensor(1.0, mstype.float32), Tensor(-1.0, mstype.float32)).asnumpy()
+    np_expected = np.array([8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0])
+    np.testing.assert_array_almost_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(1.5, mstype.float32), Tensor(-1, mstype.float32), Tensor(-18.9, mstype.float32)).asnumpy()
+    np_expected = np.array([1.5])
+    np.testing.assert_array_almost_equal(ms_out, np_expected)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_range_invalid_max_output_length():
+    with pytest.raises(ValueError):
+        _ = P.Range(0)
+        _ = P.Range(-1)
+        _ = P.Range(None)
+        _ = P.Range('5')
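For reviewers, a minimal usage sketch of the new op outside the test harness (not part of the diff; it assumes a GPU context exactly as the tests above set up, and the expected output follows the docstring example):

```python
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# maxlen bounds the memory reserved for the dynamically shaped output;
# the kernel raises a runtime error if the real length exceeds it.
range_op = P.Range(maxlen=100)
out = range_op(Tensor(0, mstype.int32), Tensor(10, mstype.int32), Tensor(4, mstype.int32))
print(out)  # expected: [0 4 8]
```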