@@ -17,6 +17,7 @@ limitations under the License. */
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
+#include <algorithm>
 #include <string>
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/operators/dropout_op.h"
@@ -26,60 +27,35 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename T, typename MaskType>
-__global__ void RandomGenerator(const size_t n, const int seed,
-                                const float dropout_prob, const T* src,
-                                MaskType* mask_data, T* dst,
-                                bool is_upscale_in_train) {
-  curandStatePhilox4_32_10_t state;
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int step_size = 0;
+// aligned vector generates vectorized load/store on CUDA
+template <typename T, int Size>
+struct alignas(sizeof(T) * Size) AlignedVector {
+  T val[Size];
+};
 
-  MaskType mask;
-  T dest;
-  for (; idx < n; idx += blockDim.x * gridDim.x) {
-    T s = src[idx];
-    if (step_size == 0) {
-      curand_init(seed, idx, idx, &state);
-      step_size = blockDim.x * gridDim.x;
-    } else {
-      curand_init(seed, idx, step_size, &state);
-    }
-    if (curand_uniform(&state) < dropout_prob) {
-      mask = 0;
-      dest = 0;
-    } else {
-      mask = 1;
-      if (is_upscale_in_train) {
-        dest = s / static_cast<T>(1.0f - dropout_prob);
-      } else {
-        dest = s;
-      }
-    }
-    mask_data[idx] = mask;
-    dst[idx] = dest;
+template <typename T>
+inline int VectorizedSize(const T* pointer) {
+  uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
+  if (address % vec4 == 0) {
+    return 4;
   }
+  return 1;
 }
 
 template <typename T, typename MaskType>
-__global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
-                                        const float dropout_prob, const T* src,
-                                        MaskType* mask_data, T* dst,
-                                        bool is_upscale_in_train) {
+__global__ void RandomGenerator(const size_t n, uint64_t seed,
+                                const float dropout_prob, const T* src,
+                                MaskType* mask_data, T* dst,
+                                bool is_upscale_in_train, uint64_t increment) {
   curandStatePhilox4_32_10_t state;
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int step_size = 0;
+  curand_init(seed, idx, increment, &state);
 
   MaskType mask;
   T dest;
   for (; idx < n; idx += blockDim.x * gridDim.x) {
     T s = src[idx];
-    if (step_size == 0) {
-      curand_init(seed[0], idx, idx, &state);
-      step_size = blockDim.x * gridDim.x;
-    } else {
-      curand_init(seed[0], idx, step_size, &state);
-    }
     if (curand_uniform(&state) < dropout_prob) {
       mask = 0;
       dest = 0;
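
Note (illustration, not part of the diff): the AlignedVector / VectorizedSize pair added above picks the load width purely from pointer alignment. A minimal host-side sketch of the same idea; the struct and the alignment check mirror the diff, while main() and the printed values exist only for demonstration:

    // Standalone sketch: choose a vector width from pointer alignment.
    // alignas(sizeof(T) * Size) lets the compiler emit a single 16-byte
    // load/store for AlignedVector<float, 4>.
    #include <cstdint>
    #include <cstdio>
    #include <type_traits>

    template <typename T, int Size>
    struct alignas(sizeof(T) * Size) AlignedVector {
      T val[Size];
    };

    template <typename T>
    int VectorizedSize(const T* p) {
      constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;
      return reinterpret_cast<std::uintptr_t>(p) % vec4 == 0 ? 4 : 1;
    }

    int main() {
      alignas(16) float buf[8];
      std::printf("%d\n", VectorizedSize(buf));      // 4: 16-byte aligned
      std::printf("%d\n", VectorizedSize(buf + 1));  // 1: only 4-byte aligned
    }

This is why the kernel launch in the last hunk falls back to the scalar RandomGenerator whenever x_data is not aligned to sizeof(T) * 4 bytes.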
@@ -96,39 +72,49 @@ __global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
   }
 }
 
-template <typename T, typename MaskType>
-__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
-                                             const float dropout_prob,
-                                             const T* src, MaskType* mask_data,
-                                             T* dst, bool is_upscale_in_train,
-                                             uint64_t increment) {
+template <typename T, typename MaskType, int VecSize>
+__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
+                                          const float dropout_prob,
+                                          const T* src, MaskType* mask_data,
+                                          T* dst, bool is_upscale_in_train,
+                                          uint64_t increment) {
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
   curandStatePhilox4_32_10_t state;
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int step_size = 0;
+  curand_init(seed, idx, increment, &state);
 
   MaskType mask;
   T dest;
-  for (; idx < n; idx += blockDim.x * gridDim.x) {
-    T s = src[idx];
-    if (step_size == 0) {
-      curand_init(seed, idx, increment, &state);
-      step_size = blockDim.x * gridDim.x;
-    } else {
-      curand_init(seed, idx, increment, &state);
-    }
-    if (curand_uniform(&state) < dropout_prob) {
-      mask = 0;
-      dest = 0;
-    } else {
-      mask = 1;
-      if (is_upscale_in_train) {
-        dest = s / static_cast<T>(1.0f - dropout_prob);
+  using LoadT = AlignedVector<T, VecSize>;
+  using MaskLoadT = AlignedVector<MaskType, VecSize>;
+  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
+    T src_vec[VecSize];
+    LoadT* value = reinterpret_cast<LoadT*>(&src_vec);
+    *value = *reinterpret_cast<const LoadT*>(&src[i]);
+    float4 rand = curand_uniform4(&state);
+
+    T dest_vec[VecSize];
+    MaskType mask_vec[VecSize];
+
+#pragma unroll
+    for (int ii = 0; ii < VecSize; ii++) {
+      if ((&rand.x)[ii] < dropout_prob) {
+        dest_vec[ii] = 0;
+        mask_vec[ii] = 0;
       } else {
-        dest = s;
+        if (is_upscale_in_train) {
+          dest_vec[ii] = src_vec[ii] * factor;
+        } else {
+          dest_vec[ii] = src_vec[ii];
+        }
+        mask_vec[ii] = 1;
       }
     }
-    mask_data[idx] = mask;
-    dst[idx] = dest;
+
+    *(reinterpret_cast<LoadT*>(&dst[i])) =
+        *reinterpret_cast<LoadT*>(&dest_vec[0]);
+    *(reinterpret_cast<MaskLoadT*>(&mask_data[i])) =
+        *reinterpret_cast<MaskLoadT*>(&mask_vec[0]);
   }
 }
 
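
Note (illustration, not part of the diff): a simplified, self-contained sketch of the access pattern VectorizedRandomGenerator uses -- one Philox subsequence per thread, curand_uniform4() producing four uniforms per draw, and a grid-stride loop that advances by blockDim.x * gridDim.x * VecSize so the four random values line up with the four lanes of one aligned load. The element type is fixed to float, the names are made up for the sketch, the mask output is omitted, and n is assumed to be a multiple of VecSize with suitably aligned buffers:

    #include <curand_kernel.h>

    template <typename T, int VecSize>
    struct alignas(sizeof(T) * VecSize) Vec {
      T val[VecSize];
    };

    template <int VecSize>
    __global__ void VectorizedDropoutSketch(size_t n, unsigned long long seed,
                                            unsigned long long offset, float p,
                                            const float* src, float* dst) {
      size_t tid = blockDim.x * blockIdx.x + threadIdx.x;
      curandStatePhilox4_32_10_t state;
      curand_init(seed, tid, offset, &state);  // one subsequence per thread

      using LoadT = Vec<float, VecSize>;
      float scale = 1.0f / (1.0f - p);
      for (size_t i = tid * VecSize; i < n;
           i += static_cast<size_t>(blockDim.x) * gridDim.x * VecSize) {
        LoadT in = *reinterpret_cast<const LoadT*>(&src[i]);  // one wide load
        float4 r = curand_uniform4(&state);                   // four uniforms
        LoadT out;
    #pragma unroll
        for (int k = 0; k < VecSize; ++k) {
          out.val[k] = (&r.x)[k] < p ? 0.0f : in.val[k] * scale;
        }
        *reinterpret_cast<LoadT*>(&dst[i]) = out;             // one wide store
      }
    }

Launched as VectorizedDropoutSketch<4><<<grid, threads>>>(...), this corresponds to the vec_size == 4 branch added in the next hunk; the scalar RandomGenerator above covers the unaligned case.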
@@ -170,36 +156,57 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       int threads = 512;
       int grid = (x_numel + threads - 1) / threads;
+      const auto& dev_ctx = context.cuda_device_context();
+      int blocks_per_sm =
+          dev_ctx.GetMaxPhysicalThreadCount() / dev_ctx.GetSMCount() / threads;
+      grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);
+
+      // increment is used to set the args(offset) of curand_init, which defines
+      // offset in subsequence.
+      // The detail:
+      // https://docs.nvidia.com/cuda/curand/device-api-overview.html
+      // Increment should be at least the number of curand() random numbers used
+      // in each thread to avoid the random number generated this time being the
+      // same as the previous calls.
+      uint64_t seed_data;
+      uint64_t increment;
+      int vec_size = VectorizedSize<T>(x_data);
+      auto offset =
+          ((x_numel - 1) / (threads * grid * vec_size) + 1) * vec_size;
+      int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
+                          .GetDeviceId();
+      auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+
       if (seed && platform::is_gpu_place(seed->place())) {
-        auto seed_gpu_data = seed->data<int>();
-        RandomGeneratorWithSeed<T, uint8_t><<<grid, threads, 0, stream>>>(
-            size, seed_gpu_data, dropout_prob, x_data, mask_data, y_data,
-            upscale_in_train);
-        return;
-      }
-      int seed_data;
-      std::random_device rnd;
-      if (seed) {
-        seed_data = *(seed->data<int>());
+        framework::Tensor seed_cpu_tensor;
+        TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
+        seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
+        increment = offset;
+      } else if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
+        auto seed_offset = gen_cuda->IncrementOffset(offset);
+        seed_data = seed_offset.first;
+        increment = seed_offset.second;
       } else {
-        seed_data =
-            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
+        if (seed) {
+          seed_data = *(seed->data<int>());
+        } else {
+          std::random_device rnd;
+          seed_data = context.Attr<bool>("fix_seed") ? context.Attr<int>("seed")
+                                                     : rnd();
+        }
+        increment = offset;
       }
 
-      int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
-                          .GetDeviceId();
-      auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
-      if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
-        auto seed_offset = gen_cuda->IncrementOffset(1);
-        RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
-            size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
-            upscale_in_train, seed_offset.second);
-        return;
+      if (vec_size == 4) {
+        VectorizedRandomGenerator<T, uint8_t, 4><<<grid, threads, 0, stream>>>(
+            size, seed_data, dropout_prob, x_data, mask_data, y_data,
+            upscale_in_train, increment);
+      } else {
+        RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
+            size, seed_data, dropout_prob, x_data, mask_data, y_data,
+            upscale_in_train, increment);
       }
-
-      RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
-          size, seed_data, dropout_prob, x_data, mask_data, y_data,
-          upscale_in_train);
     } else {
       auto X = EigenMatrix<T>::Reshape(*x, 1);
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
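
Note (illustration, not part of the diff): a host-only sketch of the launch bookkeeping in the hunk above -- cap the grid at what the device can keep resident, then advance the Philox offset by at least the number of curand draws one thread will make, which is what the offset / increment computation guarantees. MakePlan and its parameters are made up for the sketch; a real caller would take the per-SM thread limit and SM count from the CUDA runtime, as the diff does through dev_ctx:

    #include <algorithm>
    #include <cstdint>

    struct LaunchPlan {
      int grid;
      uint64_t increment;
    };

    LaunchPlan MakePlan(int64_t numel, int threads, int vec_size,
                        int max_threads_per_sm, int sm_count) {
      int grid = static_cast<int>((numel + threads - 1) / threads);
      int blocks_per_sm = max_threads_per_sm / threads;
      grid = std::min(sm_count * blocks_per_sm, grid);
      // Elements covered per grid-stride step: threads * grid * vec_size.
      // Rounding up and scaling back by vec_size mirrors
      //   offset = ((x_numel - 1) / (threads * grid * vec_size) + 1) * vec_size;
      int64_t per_step = static_cast<int64_t>(threads) * grid * vec_size;
      uint64_t steps = static_cast<uint64_t>((numel - 1) / per_step + 1);
      return {grid, steps * static_cast<uint64_t>(vec_size)};
    }

Feeding the returned increment into curand_init as the offset of the next call keeps successive dropout invocations on fresh positions of each thread's Philox subsequence.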