diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.cc
new file mode 100644
index 0000000000..4f1b01eaa3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      RangeGPUKernel, float)
+MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      RangeGPUKernel, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h
new file mode 100644
index 0000000000..06bd29bd48
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h
@@ -0,0 +1,89 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class RangeGPUKernel : public GpuKernel {
+ public:
+  RangeGPUKernel() : input_size_(0), output_size_(0), start_(0.), limit_(1.), delta_(1.) {}
+  ~RangeGPUKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    int size = SizeToInt(input_size_ / sizeof(T));
+    CalRange(size, start_, limit_, delta_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but Range needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but Range needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    auto shape_size = input_shape.size();
+    input_size_ = 1;
+    for (size_t i = 0; i < shape_size; i++) {
+      input_size_ *= input_shape[i];
+    }
+    input_size_ *= sizeof(T);
+    output_size_ = input_size_;
+    start_ = GetAttr<float>(kernel_node, "start");
+    limit_ = GetAttr<float>(kernel_node, "limit");
+    delta_ = GetAttr<float>(kernel_node, "delta");
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    return;
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+  size_t output_size_;
+  float start_;
+  float limit_;
+  float delta_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cu
new file mode 100644
index 0000000000..a2dfb407c3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cu
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_runtime.h>
+#include "range_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void Range(const int size, const float start, const float limit, const float delta, const T *input,
+                      T *output) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    output[pos] = input[pos] * delta + start;
+  }
+}
+
+template <typename T>
+void CalRange(const int size, const float start, const float limit, const float delta, const T *input, T *output,
+              cudaStream_t cuda_stream) {
+  Range<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, start, limit, delta, input, output);
+  return;
+}
+template void CalRange<float>(const int size, const float start, const float limit, const float delta,
+                              const float *input, float *output, cudaStream_t cuda_stream);
+
+template void CalRange<int>(const int size, const float start, const float limit, const float delta, const int *input,
+                            int *output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh
new file mode 100644
index 0000000000..2d0aabc5d4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_
+
+template <typename T>
+void CalRange(const int size, const float start, const float limit, const float delta, const T *input, T *output,
+              cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_
diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py
index 9219841ff7..98901058c6 100644
--- a/mindspore/nn/probability/distribution/categorical.py
+++ b/mindspore/nn/probability/distribution/categorical.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Categorical Distribution"""
-import numpy as np
 from mindspore.ops import operations as P
+import mindspore.nn as nn
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error
@@ -119,17 +119,19 @@ class Categorical(Distribution):
         """
         return self._probs
 
-    def _sample(self, sample_shape=(1,)):
+    def _sample(self, sample_shape=()):
         """
         Sampling.
 
         Args:
-            sample_shape (tuple): shape of the sample. Default: (1,).
+            sample_shape (tuple): shape of the sample. Default: ().
 
         Returns:
             Tensor, shape is shape(probs)[:-1] + sample_shape
         """
         self.checktuple(sample_shape, 'shape')
+        if sample_shape == ():
+            sample_shape = (1,)
         num_sample = 1
         for i in sample_shape:
             num_sample *= i
@@ -184,16 +186,15 @@ class Categorical(Distribution):
         if value is not None:
             check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
             value = self.expandim(self.cast(value, mstype.float32), -1)
-            index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32))
-            index = self.expandim(index, -1)
-            logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0)
-            broad_shape = self._broad_cast_shape(value, logits)
+            broad_shape = self._broad_cast_shape(value, self._logits)
             broad = P.BroadcastTo(broad_shape)
-            value = broad(value)[..., :1]
-            index = broad(index)[..., :1]
+            logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1]))
+            value = self.reshape(broad(value)[..., :1], (-1, 1))
+            index = nn.Range(0., self.shape(value)[0], 1)()
+            index = self.reshape(index, (-1, 1))
             value = self.concat((index, value))
             value = self.cast(value, mstype.int32)
-            return self.gather(logits, value)
+            return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1])
         return None
 
     def _entropy(self):
@@ -211,7 +212,7 @@
         Enumerate categories.
         """
         num_events = self._num_events
-        values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32)
+        values = nn.Range(0., num_events, 1)()
         values = self.reshape(values, (num_events, 1))
         if expand:
             values = P.BroadcastTo((num_events, self._batch_shape))(values)
diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py
index 4d6fbd63ae..0b07c0e08b 100644
--- a/mindspore/ops/operations/random_ops.py
+++ b/mindspore/ops/operations/random_ops.py
@@ -450,8 +450,8 @@
 
     Examples:
         >>> input = Tensor([0., 9., 4., 0.], mstype.float32)
-        >>> multinomial = P.Multinomial(seed=10)
-        >>> output = multinomial(input, 2, True)
+        >>> multinomial = P.Multinomial(replacement=True, seed=10)
+        >>> output = multinomial(input, 2)
     """
 
     @prim_attr_register
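
Below is a minimal usage sketch, not part of the patch itself. It assumes a GPU build of MindSpore in which nn.Range lowers to the RangeGPUKernel registered above (the same call pattern the updated categorical.py relies on), and it reuses the Multinomial call shape from the revised random_ops.py docstring; the tensor values and context settings are illustrative only.

# Illustrative sketch -- assumes a GPU build where nn.Range dispatches to the
# new RangeGPUKernel; values and context settings are examples, not fixtures.
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(device_target="GPU")

# nn.Range(start, limit, delta) materializes the sequence on device; with the
# new kernel each element is computed as start + i * delta.
index = nn.Range(0., 4., 1.)()  # expected: [0., 1., 2., 3.]

# Updated Multinomial usage: replacement is now a constructor attribute.
weights = Tensor([0., 9., 4., 0.], mstype.float32)
multinomial = P.Multinomial(replacement=True, seed=10)
samples = multinomial(weights, 2)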