!2698 GPU dropout rewrite

Merge pull request !2698 from VectorSL/drop
pull/2698/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 1066debc2b

@ -17,31 +17,59 @@
#include <stdint.h>
#include "dropout_impl.cuh"
#include "include/cuda_runtime.h"
__global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count,
template <typename T>
__global__ void DropoutForwardKernel(const T *input, T *mask, T *output, float *mask_f, size_t num_count,
float keep_prob) {
float scale = 1.f / keep_prob;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
mask[i] = mask[i] <= keep_prob;
output[i] = scale * input[i] * mask[i];
mask_f[i] = mask_f[i] <= keep_prob;
output[i] = scale * input[i] * mask_f[i];
mask[i] = mask_f[i];
}
}
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob,
template <>
__global__ void DropoutForwardKernel(const half *input, half *mask, half *output, float *mask_f,
size_t num_count, float keep_prob) {
half scale = __float2half(1.f / keep_prob);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
mask_f[i] = mask_f[i] <= keep_prob;
output[i] = scale * input[i] * __float2half(mask_f[i]);
mask[i] = __float2half(mask_f[i]);
}
}
template <typename T>
void DropoutForward(const T *input, T *mask, T *output, float *mask_f, size_t num_count, float drop_prob,
cudaStream_t cuda_stream) {
DropoutForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, num_count,
drop_prob);
DropoutForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, mask_f,
num_count, drop_prob);
}
__global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count,
template <typename T>
__global__ void DropoutBackwardKernel(const T *dy, const T *mask, T *dx, size_t num_count,
float keep_prob) {
float scale = 1.f / keep_prob;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
dx[i] = scale * dy[i] * mask[i];
}
}
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
template <>
__global__ void DropoutBackwardKernel(const half *dy, const half *mask, half *dx, size_t num_count,
float keep_prob) {
half scale = __float2half(1.f / keep_prob);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
dx[i] = scale * dy[i] * mask[i];
}
}
template <typename T>
void DropoutBackward(const T *dy, const T *mask, T *dx, size_t num_count, float drop_prob,
cudaStream_t cuda_stream) {
DropoutBackwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(dy, mask, dx, num_count, drop_prob);
}
template void DropoutForward<float>(const float *input, float *mask, float *output, float *mask_f,
size_t num_count, float drop_prob, cudaStream_t cuda_stream);
template void DropoutForward<half>(const half *input, half *mask, half *output, float *mask_f,
size_t num_count, float drop_prob, cudaStream_t cuda_stream);
template void DropoutBackward<float>(const float *dy, const float *mask, float *dx, size_t num_count,
float drop_prob, cudaStream_t cuda_stream);
template void DropoutBackward<half>(const half *dy, const half *mask, half *dx, size_t num_count,
float drop_prob, cudaStream_t cuda_stream);

@ -18,9 +18,10 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
#include "device/gpu/cuda_common.h"
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float keep_prob,
template <typename T>
void DropoutForward(const T *input, T *mask, T *output, float *mask_f, size_t num_count, float keep_prob,
cudaStream_t cuda_stream);
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float keep_prob,
cudaStream_t cuda_stream);
template <typename T>
void DropoutBackward(const T *dy, const T *mask, T *dx, size_t num_count, float keep_prob, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_

@ -15,84 +15,16 @@
*/
#include "kernel/gpu/nn/dropout_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/dropout_impl.cuh"
namespace mindspore {
namespace kernel {
DropoutGpuFwdKernel::DropoutGpuFwdKernel()
: cudnn_handle_(nullptr),
is_null_input_(false),
num_count_(0),
keep_prob_(0.0),
states_init_(false),
mask_generator_(nullptr) {}
DropoutGpuFwdKernel::~DropoutGpuFwdKernel() { DestroyResource(); }
const std::vector<size_t> &DropoutGpuFwdKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &DropoutGpuFwdKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &DropoutGpuFwdKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool DropoutGpuFwdKernel::Init(const CNodePtr &kernel_node) {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1.";
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
InitSizeLists();
return true;
}
num_count_ = 1;
for (size_t x : input_shape) {
num_count_ *= x;
}
keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob"));
InitSizeLists();
return true;
}
void DropoutGpuFwdKernel::InitResource() {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
}
void DropoutGpuFwdKernel::DestroyResource() noexcept {}
void DropoutGpuFwdKernel::InitSizeLists() {
size_t input_size = num_count_ * sizeof(float);
input_size_list_.push_back(input_size);
output_size_list_.push_back(input_size); // output size: the same with input size
output_size_list_.push_back(input_size); // mask size: the same with input size
}
bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) {
return true;
}
auto *input = reinterpret_cast<float *>(inputs[0]->addr);
auto *output = reinterpret_cast<float *>(outputs[0]->addr);
auto *mask = reinterpret_cast<float *>(outputs[1]->addr);
if (!states_init_) {
curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL));
states_init_ = true;
}
curandGenerateUniform(mask_generator_, mask, num_count_);
DropoutForward(input, mask, output, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
MS_REG_GPU_KERNEL_ONE(
Dropout,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
DropoutGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
Dropout,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
DropoutGpuFwdKernel, half)
} // namespace kernel
} // namespace mindspore

@ -20,35 +20,88 @@
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/dropout_impl.cuh"
#include "include/curand.h"
namespace mindspore {
namespace kernel {
template <typename T>
class DropoutGpuFwdKernel : public GpuKernel {
public:
DropoutGpuFwdKernel();
DropoutGpuFwdKernel()
: cudnn_handle_(nullptr),
is_null_input_(false),
num_count_(0),
keep_prob_(0.0),
states_init_(false),
mask_generator_(nullptr) {}
~DropoutGpuFwdKernel() override;
~DropoutGpuFwdKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
const std::vector<size_t> &GetWorkspaceSizeList() const override;
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
T *mask = GetDeviceAddress<T>(outputs, 1);
float *mask_f = GetDeviceAddress<float>(workspace, 0);
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
if (!states_init_) {
curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL));
states_init_ = true;
}
// curandGen only support float or double for mask.
curandGenerateUniform(mask_generator_, mask_f, num_count_);
DropoutForward(input, mask, output, mask_f, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1.";
}
bool Init(const CNodePtr &kernel_node) override;
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
InitSizeLists();
return true;
}
num_count_ = 1;
for (size_t x : input_shape) {
num_count_ *= x;
}
keep_prob_ = GetAttr<float>(kernel_node, "keep_prob");
InitSizeLists();
return true;
}
protected:
void InitResource() override;
void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitSizeLists() override;
void InitSizeLists() override {
size_t input_size = num_count_ * sizeof(T);
input_size_list_.push_back(input_size);
output_size_list_.push_back(input_size); // output size: the same with input size
output_size_list_.push_back(input_size); // mask size: the same with input size
workspace_size_list_.push_back(num_count_ * sizeof(float)); // temp mask_f for curandGen
}
private:
void DestroyResource() noexcept;
cudnnHandle_t cudnn_handle_;
bool is_null_input_;
size_t num_count_;
@ -59,8 +112,6 @@ class DropoutGpuFwdKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
MS_REG_GPU_KERNEL(Dropout, DropoutGpuFwdKernel)
} // namespace kernel
} // namespace mindspore

@ -15,76 +15,16 @@
*/
#include "kernel/gpu/nn/dropout_grad_kernel.h"
#include "kernel/gpu/cuda_impl/dropout_impl.cuh"
namespace mindspore {
namespace kernel {
DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel()
: cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {}
DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); }
const std::vector<size_t> &DropoutGradGpuFwdKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &DropoutGradGpuFwdKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &DropoutGradGpuFwdKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool DropoutGradGpuFwdKernel::Init(const CNodePtr &kernel_node) {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuFwdKernel needs 2.";
return false;
}
auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
InitSizeLists();
return true;
}
num_count_ = 1;
for (size_t x : input_shape) {
num_count_ *= x;
}
keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob"));
InitSizeLists();
return true;
}
void DropoutGradGpuFwdKernel::InitResource() {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
}
void DropoutGradGpuFwdKernel::DestroyResource() noexcept {}
void DropoutGradGpuFwdKernel::InitSizeLists() {
size_t dy_size = num_count_ * sizeof(float);
size_t mask_size = dy_size;
size_t dx_size = dy_size;
input_size_list_.push_back(dy_size);
input_size_list_.push_back(mask_size);
output_size_list_.push_back(dx_size);
}
bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) {
return true;
}
auto *dy = reinterpret_cast<float *>(inputs[0]->addr);
auto *mask = reinterpret_cast<float *>(inputs[1]->addr);
auto *dx = reinterpret_cast<float *>(outputs[0]->addr);
DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
MS_REG_GPU_KERNEL_ONE(
DropoutGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
DropoutGradGpuBwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
DropoutGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
DropoutGradGpuBwdKernel, half)
} // namespace kernel
} // namespace mindspore

@ -20,28 +20,72 @@
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/dropout_impl.cuh"
namespace mindspore {
namespace kernel {
class DropoutGradGpuFwdKernel : public GpuKernel {
template <typename T>
class DropoutGradGpuBwdKernel : public GpuKernel {
public:
DropoutGradGpuFwdKernel();
~DropoutGradGpuFwdKernel() override;
DropoutGradGpuBwdKernel() : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {}
~DropoutGradGpuBwdKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *dy = GetDeviceAddress<T>(inputs, 0);
T *mask = GetDeviceAddress<T>(inputs, 1);
T *dx = GetDeviceAddress<T>(outputs, 0);
DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuBwdKernel needs 2.";
return false;
}
auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
InitSizeLists();
return true;
}
num_count_ = 1;
for (size_t x : input_shape) {
num_count_ *= x;
}
keep_prob_ = GetAttr<float>(kernel_node, "keep_prob");
InitSizeLists();
return true;
}
protected:
void InitResource() override;
void InitSizeLists() override;
void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitSizeLists() override {
size_t dy_size = num_count_ * sizeof(T);
size_t mask_size = dy_size;
size_t dx_size = dy_size;
private:
void DestroyResource() noexcept;
input_size_list_.push_back(dy_size);
input_size_list_.push_back(mask_size);
output_size_list_.push_back(dx_size);
}
private:
cudnnHandle_t cudnn_handle_;
bool is_null_input_;
size_t num_count_;
@ -50,8 +94,6 @@ class DropoutGradGpuFwdKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
MS_REG_GPU_KERNEL(DropoutGrad, DropoutGradGpuFwdKernel)
} // namespace kernel
} // namespace mindspore

@ -4460,6 +4460,7 @@ class Dropout(PrimitiveWithInfer):
def infer_dtype(self, x_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
return x_dtype, x_dtype
@ -4494,6 +4495,8 @@ class DropoutGrad(PrimitiveWithInfer):
def infer_dtype(self, dy_dtype, mask_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name)
validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
return dy_dtype

Loading…
Cancel
Save