|
|
|
@ -17,13 +17,59 @@ limitations under the License. */
|
|
|
|
|
#include <random>
|
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/generator.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/platform/gpu_launch_config.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
// aligned vector generates vectorized load/store on CUDA
|
|
|
|
|
template <typename T, int Size>
|
|
|
|
|
struct alignas(sizeof(T) * Size) AlignedVector {
|
|
|
|
|
T val[Size];
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline int VectorizedSize(const T* pointer) {
|
|
|
|
|
uint64_t address = reinterpret_cast<uint64_t>(pointer);
|
|
|
|
|
constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value; // NOLINT
|
|
|
|
|
if (address % vec4 == 0) {
|
|
|
|
|
return 4;
|
|
|
|
|
}
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
template <typename T, typename MaskType, int VecSize>
|
|
|
|
|
__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
|
|
|
|
|
const T factor, const int64_t size,
|
|
|
|
|
T* dx) {
|
|
|
|
|
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
using LoadT = AlignedVector<T, VecSize>;
|
|
|
|
|
using MaskLoadT = AlignedVector<MaskType, VecSize>;
|
|
|
|
|
|
|
|
|
|
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
|
|
|
|
|
T dout_vec[VecSize];
|
|
|
|
|
LoadT* value = reinterpret_cast<LoadT*>(&dout_vec);
|
|
|
|
|
*value = *reinterpret_cast<const LoadT*>(&dout[i]);
|
|
|
|
|
|
|
|
|
|
T dx_vec[VecSize];
|
|
|
|
|
MaskType mask_vec[VecSize];
|
|
|
|
|
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int ii = 0; ii < VecSize; ii++) {
|
|
|
|
|
dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*(reinterpret_cast<LoadT*>(&dx[i])) = *reinterpret_cast<LoadT*>(&dx_vec[0]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
@ -119,6 +165,7 @@ class DropoutGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* mask = context.Input<Tensor>("Mask");
|
|
|
|
|
grad_x->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto size = grad_x->numel();
|
|
|
|
|
|
|
|
|
|
auto M = EigenVector<uint8_t>::Flatten(*mask);
|
|
|
|
|
auto dX = EigenVector<T>::Flatten(*grad_x);
|
|
|
|
@ -126,7 +173,6 @@ class DropoutGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto& place =
|
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
|
|
|
|
|
auto& dropout_implementation =
|
|
|
|
|
context.Attr<std::string>("dropout_implementation");
|
|
|
|
|
if (dropout_implementation == "upscale_in_train") {
|
|
|
|
@ -134,8 +180,24 @@ class DropoutGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (dropout_prob == 1.0f) {
|
|
|
|
|
dX.device(place) = static_cast<T>(0) * dY;
|
|
|
|
|
} else {
|
|
|
|
|
dX.device(place) =
|
|
|
|
|
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
|
|
|
|
|
int vec_size = VectorizedSize<T>(grad_y->data<T>());
|
|
|
|
|
if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
|
|
|
|
|
size % 4 == 0) {
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
|
|
|
|
|
auto stream = context.cuda_device_context().stream();
|
|
|
|
|
platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
|
|
|
|
|
context.cuda_device_context(), size);
|
|
|
|
|
DropoutGradCUDAKernel<
|
|
|
|
|
T, uint8_t,
|
|
|
|
|
4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
|
|
|
|
|
grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
|
|
|
|
|
grad_x->data<T>());
|
|
|
|
|
#endif
|
|
|
|
|
} else {
|
|
|
|
|
dX.device(place) =
|
|
|
|
|
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
dX.device(place) = dY * M.cast<T>();
|
|
|
|
|