From cccb230f7bb0130ca052ecefa2fe8071931ceaf0 Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Wed, 8 Jul 2020 00:29:33 -0300 Subject: [PATCH] Add random normal cuda implementation on GPU --- .../kernel/gpu/cuda_impl/random_op_impl.cu | 42 ++++++ .../kernel/gpu/cuda_impl/random_op_impl.cuh | 26 ++++ .../kernel/gpu/math/random_op_gpu_kernel.cc | 24 ++++ .../kernel/gpu/math/random_op_gpu_kernel.h | 121 ++++++++++++++++++ 4 files changed, 213 insertions(+) create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cu create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cuh create mode 100644 mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.h diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cu new file mode 100644 index 0000000000..6f99394562 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cu @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "random_op_impl.cuh" +template +__global__ void NormalKernel(int seed, curandState *globalState, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + curand_init(seed, i, 0, &globalState[i]); + output[i] = curand_normal(&globalState[i]); + } + return; +} + +template +void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) { + int RNG_seed = 0; + if (seed2 != 0) { + RNG_seed = seed2; + } else if (seed != 0) { + RNG_seed = seed; + } else { + RNG_seed = time(NULL); + } + NormalKernel<<>>(RNG_seed, globalState, output, count); + return; +} + +template void StandardNormal(int seed, int seed2, curandState *globalState, + float *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cuh new file mode 100644 index 0000000000..5e9110a1bc --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ + +#include +#include "device/gpu/cuda_common.h" + +template +void StandardNormal(int seed, int seed2, curandState *globalState, + T *output, size_t count, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.cc new file mode 100644 index 0000000000..d54fe285c2 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/math/random_op_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + RandomOpGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.h new file mode 100644 index 0000000000..3767cd9fc8 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.h @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" +#include "kernel/gpu/cuda_impl/random_op_impl.cuh" + +namespace mindspore { +namespace kernel { +enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_INVALID_TYPE = 255 }; + +const std::map kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}}; +template +class RandomOpGpuKernel : public GpuKernel { + public: + RandomOpGpuKernel() + : random_op_type_(RANDOM_OP_INVALID_TYPE), + input_size_0_(0), + output_size_(sizeof(T)), + workspace_size_(sizeof(curandState)) {} + ~RandomOpGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + void *workspace_addr = GetDeviceAddress(workspace, 0); + curandState *devStates = reinterpret_cast(workspace_addr); + T *output_addr = GetDeviceAddress(outputs, 0); + + switch (random_op_type_) { + case RANDOM_OP_NORMAL: { + StandardNormal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T), + reinterpret_cast(stream_ptr)); + break; + } + default: { + MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported."; + } + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kRandomOpTypeMap.find(kernel_name); + if (iter == kRandomOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "Random operation " << kernel_name << " is not supported."; + } else { + random_op_type_ = iter->second; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output."; + return false; + } + auto input_shape_0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape_0.size(); i++) { + input_size_0_ += input_shape_0[i]; + } + input_size_0_ *= sizeof(int); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= output_shape[i]; + workspace_size_ *= output_shape[i]; + } + seed_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); + seed2_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_0_); + output_size_list_.push_back(output_size_); + workspace_size_list_.push_back(workspace_size_); + } + + private: + RandomOptype random_op_type_; + size_t input_size_0_; + size_t output_size_; + size_t workspace_size_; + int seed_; + int seed2_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_