diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.cc
new file mode 100644
index 0000000000..4f1b01eaa3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      RangeGPUKernel, float)
+MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      RangeGPUKernel, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h
new file mode 100644
index 0000000000..06bd29bd48
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h
@@ -0,0 +1,89 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class RangeGPUKernel : public GpuKernel {
+ public:
+  RangeGPUKernel() : input_size_(0), output_size_(0), start_(0.), limit_(1.), delta_(1.) {}
+  ~RangeGPUKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    int size = SizeToInt(input_size_ / sizeof(T));
+    CalRange(size, start_, limit_, delta_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but Range needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but Range needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    auto shape_size = input_shape.size();
+    input_size_ = 1;
+    for (size_t i = 0; i < shape_size; i++) {
+      input_size_ *= input_shape[i];
+    }
+    input_size_ *= sizeof(T);
+    output_size_ = input_size_;
+    start_ = GetAttr<float>(kernel_node, "start");
+    limit_ = GetAttr<float>(kernel_node, "limit");
+    delta_ = GetAttr<float>(kernel_node, "delta");
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    return;
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+  size_t output_size_;
+  float start_;
+  float limit_;
+  float delta_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cu
new file mode 100644
index 0000000000..a2dfb407c3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cu
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_runtime.h>
+#include "range_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void Range(const int size, const float start, const float limit, const float delta, const T *input,
+                      T *output) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    output[pos] = input[pos] * delta + start;
+  }
+}
+
+template <typename T>
+void CalRange(const int size, const float start, const float limit, const float delta, const T *input, T *output,
+              cudaStream_t cuda_stream) {
+  Range<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, start, limit, delta, input, output);
+  return;
+}
+template void CalRange<float>(const int size, const float start, const float limit, const float delta,
+                              const float *input, float *output, cudaStream_t cuda_stream);
+
+template void CalRange<int>(const int size, const float start, const float limit, const float delta, const int *input,
+                            int *output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh
new file mode 100644
index 0000000000..2d0aabc5d4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_
+
+template <typename T>
+void CalRange(const int size, const float start, const float limit, const float delta, const T *input, T *output,
+              cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_
diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py
index 9219841ff7..98901058c6 100644
--- a/mindspore/nn/probability/distribution/categorical.py
+++ b/mindspore/nn/probability/distribution/categorical.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Categorical Distribution"""
-import numpy as np
 from mindspore.ops import operations as P
+import mindspore.nn as nn
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error
@@ -119,17 +119,19 @@ class Categorical(Distribution):
         """
         return self._probs
 
-    def _sample(self, sample_shape=(1,)):
+    def _sample(self, sample_shape=()):
         """
         Sampling.
 
         Args:
-            sample_shape (tuple): shape of the sample. Default: (1,).
+            sample_shape (tuple): shape of the sample. Default: ().
 
         Returns:
             Tensor, shape is shape(probs)[:-1] + sample_shape
         """
         self.checktuple(sample_shape, 'shape')
+        if sample_shape == ():
+            sample_shape = (1,)
         num_sample = 1
         for i in sample_shape:
             num_sample *= i
@@ -184,16 +186,15 @@ class Categorical(Distribution):
         if value is not None:
             check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
             value = self.expandim(self.cast(value, mstype.float32), -1)
-            index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32))
-            index = self.expandim(index, -1)
-            logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0)
-            broad_shape = self._broad_cast_shape(value, logits)
+            broad_shape = self._broad_cast_shape(value, self._logits)
             broad = P.BroadcastTo(broad_shape)
-            value = broad(value)[..., :1]
-            index = broad(index)[..., :1]
+            logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1]))
+            value = self.reshape(broad(value)[..., :1], (-1, 1))
+            index = nn.Range(0., self.shape(value)[0], 1)()
+            index = self.reshape(index, (-1, 1))
             value = self.concat((index, value))
             value = self.cast(value, mstype.int32)
-            return self.gather(logits, value)
+            return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1])
         return None
 
     def _entropy(self):
@@ -211,7 +212,7 @@
         Enumerate categories.
         """
         num_events = self._num_events
-        values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32)
+        values = nn.Range(0., num_events, 1)()
         values = self.reshape(values, (num_events, 1))
         if expand:
             values = P.BroadcastTo((num_events, self._batch_shape))(values)
diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py
index 4d6fbd63ae..0b07c0e08b 100644
--- a/mindspore/ops/operations/random_ops.py
+++ b/mindspore/ops/operations/random_ops.py
@@ -450,8 +450,8 @@
 
     Examples:
         >>> input = Tensor([0., 9., 4., 0.], mstype.float32)
-        >>> multinomial = P.Multinomial(seed=10)
-        >>> output = multinomial(input, 2, True)
+        >>> multinomial = P.Multinomial(replacement=True, seed=10)
+        >>> output = multinomial(input, 2)
     """
 
     @prim_attr_register
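
Below is a minimal usage sketch, not part of the patch itself. It assumes a GPU build of MindSpore in which nn.Range lowers to the RangeGPUKernel registered above (the same call pattern the updated categorical.py relies on), and it reuses the Multinomial call shape from the revised random_ops.py docstring; the tensor values and context settings are illustrative only.

# Illustrative sketch -- assumes a GPU build where nn.Range dispatches to the
# new RangeGPUKernel; values and context settings are examples, not fixtures.
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(device_target="GPU")

# nn.Range(start, limit, delta) materializes the sequence on device; with the
# new kernel each element is computed as start + i * delta.
index = nn.Range(0., 4., 1.)()  # expected: [0., 1., 2., 3.]

# Updated Multinomial usage: replacement is now a constructor attribute.
weights = Tensor([0., 9., 4., 0.], mstype.float32)
multinomial = P.Multinomial(replacement=True, seed=10)
samples = multinomial(weights, 2)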