|
|
|
@ -17,31 +17,59 @@
|
|
|
|
|
#include <stdint.h>
|
|
|
|
|
#include "dropout_impl.cuh"
|
|
|
|
|
#include "include/cuda_runtime.h"
|
|
|
|
|
|
|
|
|
|
__global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count,
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void DropoutForwardKernel(const T *input, T *mask, T *output, float *mask_f, size_t num_count,
|
|
|
|
|
float keep_prob) {
|
|
|
|
|
float scale = 1.f / keep_prob;
|
|
|
|
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
|
|
|
|
|
mask[i] = mask[i] <= keep_prob;
|
|
|
|
|
output[i] = scale * input[i] * mask[i];
|
|
|
|
|
mask_f[i] = mask_f[i] <= keep_prob;
|
|
|
|
|
output[i] = scale * input[i] * mask_f[i];
|
|
|
|
|
mask[i] = mask_f[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob,
|
|
|
|
|
template <>
|
|
|
|
|
__global__ void DropoutForwardKernel(const half *input, half *mask, half *output, float *mask_f,
|
|
|
|
|
size_t num_count, float keep_prob) {
|
|
|
|
|
half scale = __float2half(1.f / keep_prob);
|
|
|
|
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
|
|
|
|
|
mask_f[i] = mask_f[i] <= keep_prob;
|
|
|
|
|
output[i] = scale * input[i] * __float2half(mask_f[i]);
|
|
|
|
|
mask[i] = __float2half(mask_f[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
template <typename T>
|
|
|
|
|
void DropoutForward(const T *input, T *mask, T *output, float *mask_f, size_t num_count, float drop_prob,
|
|
|
|
|
cudaStream_t cuda_stream) {
|
|
|
|
|
DropoutForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, num_count,
|
|
|
|
|
drop_prob);
|
|
|
|
|
DropoutForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, mask_f,
|
|
|
|
|
num_count, drop_prob);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count,
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void DropoutBackwardKernel(const T *dy, const T *mask, T *dx, size_t num_count,
|
|
|
|
|
float keep_prob) {
|
|
|
|
|
float scale = 1.f / keep_prob;
|
|
|
|
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
|
|
|
|
|
dx[i] = scale * dy[i] * mask[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
|
|
|
|
|
template <>
|
|
|
|
|
__global__ void DropoutBackwardKernel(const half *dy, const half *mask, half *dx, size_t num_count,
|
|
|
|
|
float keep_prob) {
|
|
|
|
|
half scale = __float2half(1.f / keep_prob);
|
|
|
|
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
|
|
|
|
|
dx[i] = scale * dy[i] * mask[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
template <typename T>
|
|
|
|
|
void DropoutBackward(const T *dy, const T *mask, T *dx, size_t num_count, float drop_prob,
|
|
|
|
|
cudaStream_t cuda_stream) {
|
|
|
|
|
DropoutBackwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(dy, mask, dx, num_count, drop_prob);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template void DropoutForward<float>(const float *input, float *mask, float *output, float *mask_f,
|
|
|
|
|
size_t num_count, float drop_prob, cudaStream_t cuda_stream);
|
|
|
|
|
template void DropoutForward<half>(const half *input, half *mask, half *output, float *mask_f,
|
|
|
|
|
size_t num_count, float drop_prob, cudaStream_t cuda_stream);
|
|
|
|
|
template void DropoutBackward<float>(const float *dy, const float *mask, float *dx, size_t num_count,
|
|
|
|
|
float drop_prob, cudaStream_t cuda_stream);
|
|
|
|
|
template void DropoutBackward<half>(const half *dy, const half *mask, half *dx, size_t num_count,
|
|
|
|
|
float drop_prob, cudaStream_t cuda_stream);
|
|
|
|
|