diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/linspace.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/linspace.cu
new file mode 100644
index 0000000000..2fa18bf647
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/linspace.cu
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/linspace.cuh"
+#include <cstddef>
+
+// Fills output[0 .. value_count) with evenly spaced values from *start to *stop (both inclusive).
+// start, stop and output are dereferenced inside the kernel. Every thread walks the output with a stride of
+// the total number of launched threads, so any 1-D launch shape covers all elements.
+template <typename T>
+__global__ void LinSpaceKernel(const T *start, const T *stop, const size_t value_count, T *output) {
+  // With a single tick there is no interval to divide: (value_count - 1) == 0 would make the step inf/NaN and
+  // output[0] = *start + step * 0 would become NaN. Use a zero step so the only element is exactly *start.
+  const T add_value = (value_count > 1) ? ((*stop - *start) / (value_count - 1)) : static_cast<T>(0);
+  // Widen before multiplying so the index arithmetic cannot wrap in 32 bits.
+  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
+  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < value_count; i += stride) {
+    output[i] = *start + (add_value * i);
+  }
+}
+
+// Host wrapper: enqueues LinSpaceKernel on cuda_stream. Does nothing when value_count is 0.
+template <typename T>
+void calLinSpace(const T *start, const T *stop, const size_t value_count, T *output, cudaStream_t cuda_stream) {
+  if (value_count == 0) {
+    return;
+  }
+  // NOTE(review): the launch configuration was missing from this copy of the patch; GET_BLOCKS/GET_THREADS are
+  // assumed to be the launch helpers made available through cuda_common.h - confirm.
+  LinSpaceKernel<<<GET_BLOCKS(value_count), GET_THREADS, 0, cuda_stream>>>(start, stop, value_count, output);
+}
+template void calLinSpace(const float *start, const float *stop, const size_t value_count, float *output,
+                          cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/linspace.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/linspace.cuh
new file mode 100644
index 0000000000..59b232d7ad
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/linspace.cuh
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_LINSPACE_IMPL_CU_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_LINSPACE_IMPL_CU_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+
+// Writes value_count evenly spaced values from *start to *stop (both inclusive) into output, enqueueing the work
+// on cuda_stream. start, stop and output are read/written by the GPU kernel, so they must be device-accessible.
+template <typename T>
+void calLinSpace(const T *start, const T *stop, const size_t value_count, T *output, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_LINSPACE_IMPL_CU_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/linspace.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/linspace.cc
new file mode 100644
index 0000000000..2231741d3a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/linspace.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/linspace.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(LinSpace,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      LinSpaceGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/linspace.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/linspace.h
new file mode 100644
index 0000000000..1d0a16c23a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/linspace.h
@@ -0,0 +1,111 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LINSPACE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LINSPACE_GPU_KERNEL_H_
+
+#include <cstddef>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/linspace.cuh"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+
+namespace mindspore {
+namespace kernel {
+// GPU kernel for LinSpace: fills a 1-D output with value_count_ evenly spaced values between the scalar inputs
+// start (input 0) and stop (input 1). The element count is taken from the output shape in Init.
+template <typename T>
+class LinSpaceGpuKernel : public GpuKernel {
+ public:
+  LinSpaceGpuKernel() { ResetResource(); }
+  ~LinSpaceGpuKernel() = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  // Enqueues the fill on stream_ptr; inputs[0]/inputs[1] are start/stop and outputs[0] receives the result.
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *start_addr = GetDeviceAddress<T>(inputs, 0);
+    T *stop_addr = GetDeviceAddress<T>(inputs, 1);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+    calLinSpace(start_addr, stop_addr, value_count_, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  // Validates the node (3 inputs, 1 output, 0-D start/stop, rank-1 output) and caches the output length.
+  // Returns false, after logging, on any mismatch.
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 3) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but LinSpace needs 3 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but LinSpace needs 1 output.";
+      return false;
+    }
+    auto input_1 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
+    auto input_2 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
+    // error checking input data
+    if ((input_1.size() != 0) || (input_2.size() != 0)) {
+      MS_LOG(ERROR) << "For LinSpace "
+                    << "both start and end must be 0-D Tensors. Got " << input_1.size() << " and " << input_2.size()
+                    << ".";
+      return false;
+    }
+    auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
+    if (value_count.size() != 1) {
+      MS_LOG(ERROR) << "For LinSpace, output shape incorrect rank. Expect Rank: 1, got Rank: " << value_count.size()
+                    << ".";
+      // Bail out here: indexing value_count[0] below is only valid for a rank-1 shape.
+      return false;
+    }
+    value_count_ = value_count[0];
+    InitSizeLists();
+    return true;
+  }
+
+  // Returns the kernel to its freshly constructed state.
+  void ResetResource() noexcept override {
+    value_count_ = 0;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitSizeLists() override {
+    // NOTE(review): Init requires three inputs but only the two scalar inputs get a size entry here - confirm
+    // that the framework does not expect an entry for the third (num) input.
+    input_size_list_.push_back(sizeof(T));  // Scalar tensor
+    input_size_list_.push_back(sizeof(T));  // Scalar tensor
+    output_size_list_.push_back(value_count_ * sizeof(T));
+  }
+
+ private:
+  size_t value_count_;  // Number of output elements, taken from the output shape in Init.
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LINSPACE_GPU_KERNEL_H_
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 39db3ef917..c319b43488 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -247,6 +247,8 @@ AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &
                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplGpuConvertToDynamicShape(const
AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc
index 0d8b301650..d5da5d94e8 100644
--- a/mindspore/core/abstract/prim_maths.cc
+++ b/mindspore/core/abstract/prim_maths.cc
@@ -167,5 +167,51 @@ AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine_ptr, const Pri
                                   const AbstractBasePtrList &args_spec_list) {
   return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
 }
+
+// Infers the output of LinSpace: a 1-D tensor with `num` elements and the element type of `start`.
+// args_spec_list holds start (tensor), stop (tensor) and num, which is either a tensor (dynamic-shape path) or a
+// scalar. Raises an exception for a wrong argument count/type or a negative num.
+AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(start);
+  MS_EXCEPTION_IF_NULL(start->shape());
+  auto stop = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(stop);
+  MS_EXCEPTION_IF_NULL(stop->shape());
+  (void)CheckTensorDType(start, {kFloat32}, "Input 0 (start) for LinSpace should be %s");
+  (void)CheckTensorDType(stop, {kFloat32}, "Input 1 (stop) for LinSpace should be %s");
+  int64_t num_val = 0;
+  // 3rd input is a Tensor when LinSpace is a dynamic shape operator
+  if (args_spec_list[2]->isa<AbstractTensor>()) {
+    auto num = args_spec_list[2]->cast<AbstractTensorPtr>();
+    MS_EXCEPTION_IF_NULL(num);
+    auto num_value_ptr = num->BuildValue();
+    MS_EXCEPTION_IF_NULL(num_value_ptr);
+    auto num_tensor = num_value_ptr->cast<tensor::TensorPtr>();
+    MS_EXCEPTION_IF_NULL(num_tensor);
+    // NOTE(review): reads the payload as int64 without checking the tensor dtype - confirm callers guarantee it.
+    num_val = *static_cast<int64_t *>(num_tensor->data_c());
+  } else if (args_spec_list[2]->isa<AbstractScalar>()) {
+    auto num = args_spec_list[2]->cast<AbstractScalarPtr>();
+    MS_EXCEPTION_IF_NULL(num);
+    num_val = GetValue<int64_t>(num->BuildValue());
+  } else {
+    MS_LOG(EXCEPTION) << "Invalid abstract type:" << args_spec_list[2]->type_name();
+  }
+  // Validate before the value is used as a dimension.
+  if (num_val < 0) {
+    MS_LOG(EXCEPTION) << "num must be >= 0 in LinSpace";
+  }
+  // The length is fully known here, so shape, min_shape and max_shape coincide.
+  ShapeVector shape = {num_val};
+  ShapeVector min_shape = {num_val};
+  ShapeVector max_shape = {num_val};
+  AbstractTensorPtr ret =
+    std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+  return ret;
+}
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index db7767ef25..550550ae85 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -45,6 +45,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimEqual, {InferImplEqual, true}},
     {prim::kPrimMinimum, {InferImplMinimum, true}},
     {prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
+    {prim::kPrimLinSpace, {InferImplLinSpace, true}},
     // Array
     {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 630e1bb7a7..3e758ab1d6 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -241,6 +241,7 @@ inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
 inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
 inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
 inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV");
+inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace");
 
 // Statements
 inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index b8cfd8ec73..88b51c7cc3 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -3946,15 +3946,19 @@ class Eps(PrimitiveWithInfer):
 
 class LinSpace(PrimitiveWithInfer):
     r"""
-    Generates values in an interval and returns the corresponding interpolation accroding to assist.
+    Generates values in an interval (inclusive of start and stop) and returns the corresponding
+    interpolated array with **num** number of ticks.
 
     Inputs:
-        - **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
-        - **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
-        - **num** (int) - Ticks number in the interval, the ticks include start and stop value.
+        - **start** (Tensor[float32]) - Start value of interval, with shape of 0-D.
+        - **stop** (Tensor[float32]) - Last value of interval, with shape of 0-D.
+        - **num** (int) - Number of ticks in the interval, inclusive of start and stop.
 
     Outputs:
-        Tensor, has the same shape as `assist`.
+        Tensor, 1-D with `num` elements and the same dtype as `start`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> linspace = P.LinSpace()
diff --git a/tests/st/ops/gpu/test_lin_space.py b/tests/st/ops/gpu/test_lin_space.py
new file mode 100644
index 0000000000..b950db1eab
--- /dev/null
+++ b/tests/st/ops/gpu/test_lin_space.py
@@ -0,0 +1,106 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.common.dtype as mstype
+import mindspore.context as context
+from mindspore.common.tensor import Tensor
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+
+
+class LinSpaceNet(Cell):
+    """Wraps P.LinSpace with a fixed tick count so the op can be run through a Cell."""
+
+    def __init__(self, num):
+        super(LinSpaceNet, self).__init__()
+        self.ls_op = P.LinSpace()
+        self.num = num
+
+    def construct(self, start, stop):
+        output = self.ls_op(start, stop, self.num)
+        return output
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_lin_space_1():
+    """Graph mode, ascending interval, operator called directly."""
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    start_np = 5
+    stop_np = 150
+    num_np = 12
+    start = Tensor(start_np, dtype=mstype.float32)
+    stop = Tensor(stop_np, dtype=mstype.float32)
+    num = num_np
+    ls_op = P.LinSpace()
+    result_ms = ls_op(start, stop, num).asnumpy()
+    result_np = np.linspace(start_np, stop_np, num_np)
+    assert np.allclose(result_ms, result_np)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_lin_space_2():
+    """PyNative mode, interval crossing zero, operator called directly."""
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    start_np = -25
+    stop_np = 147
+    num_np = 10
+    start = Tensor(start_np, dtype=mstype.float32)
+    stop = Tensor(stop_np, dtype=mstype.float32)
+    num = num_np
+    ls_op = P.LinSpace()
+    result_ms = ls_op(start, stop, num).asnumpy()
+    result_np = np.linspace(start_np, stop_np, num_np)
+    assert np.allclose(result_ms, result_np)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_lin_space_3():
+    """Graph mode, descending interval, operator wrapped in LinSpaceNet."""
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    start_np = 25
+    stop_np = -147
+    num_np = 20
+    start = Tensor(start_np, dtype=mstype.float32)
+    stop = Tensor(stop_np, dtype=mstype.float32)
+    net = LinSpaceNet(num_np)
+    result_ms = net(start, stop).asnumpy()
+    result_np = np.linspace(start_np, stop_np, num_np)
+    assert np.allclose(result_ms, result_np)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_lin_space_4():
+    """Graph mode, descending interval with a fractional start, operator wrapped in LinSpaceNet."""
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    start_np = -25.3
+    stop_np = -147
+    num_np = 36
+    start = Tensor(start_np, dtype=mstype.float32)
+    stop = Tensor(stop_np, dtype=mstype.float32)
+    net = LinSpaceNet(num_np)
+    result_ms = net(start, stop).asnumpy()
+    result_np = np.linspace(start_np, stop_np, num_np)
+    assert np.allclose(result_ms, result_np)