@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

namespace paddle {
namespace platform {
@@ -39,6 +40,13 @@ using Tensor = framework::Tensor;
        out_data, x->data<T>(), N, dim, dim); \
    break;

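// Backward counterpart of the forward launch macro above: each expansion adds
// one `case` label to the dispatch switch in the grad kernel and launches
// softmax_warp_backward for rows padded to 2^Log2Elements elements, using the
// caller's `blocks`, `threads`, `dx_data` and `mul_grad`.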
#define LAUNCH_SOFTMAX_WARP_BACKWARD(Log2Elements)                 \
  case Log2Elements:                                               \
    softmax_warp_backward<T, float, Log2Elements><<<               \
        blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
        dx_data, mul_grad.data<T>(), out->data<T>(), N, dim, dim); \
    break;

static inline int SizeOutAxis(const int axis, DDim dims) {
  int size = 1;
  for (int i = axis + 1; i < dims.size(); i++) {
@@ -199,6 +207,83 @@ __global__ void WarpSoftmaxForward(T* dst, const T* src, const int batch_size,
  }
}

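// Warp-level softmax backward. `grad` is expected to already hold the
// elementwise product dout * out (see MultiplyCUDAKernel below), so the
// gradient is dx = grad - out * sum(grad) = out * (dout - sum(dout * out)).
// Each warp of warp_size_softmax lanes handles WARP_BATCH rows, with every
// thread keeping WARP_ITERATIONS elements of each row in registers.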
template <typename T, typename AccT, int Log2Elements>
__global__ void softmax_warp_backward(T* gradInput, const T* grad,
                                      const T* output, int batch_size,
                                      int stride, int element_count) {
  constexpr int next_power_of_two = 1 << Log2Elements;
  constexpr int warp_size_softmax =
      (next_power_of_two < 32) ? next_power_of_two : 32;
  constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) {
    local_batches = WARP_BATCH;
  }
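
  // lane id within the logical warp (warp_size_softmax <= 32 threads)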
  int local_idx = threadIdx.x % warp_size_softmax;

  int thread_offset = first_batch * stride + local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  AccT grad_reg[WARP_BATCH][WARP_ITERATIONS];
  AccT output_reg[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * warp_size_softmax;
      if (element_index < batch_element_count) {
        grad_reg[i][it] =
            static_cast<AccT>(grad[i * element_count + it * warp_size_softmax]);
        output_reg[i][it] = static_cast<AccT>(
            output[i * element_count + it * warp_size_softmax]);
      } else {
        grad_reg[i][it] = AccT(0);
        output_reg[i][it] = AccT(0);
      }
    }
  }
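
  // per-thread partial sums of dout * out for each row; warp_reduce_sum then
  // combines the partials across the warp lanes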
  AccT sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce_sum<AccT, WARP_BATCH, warp_size_softmax>(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * warp_size_softmax;
      if (element_index < element_count) {
        // compute gradients
        gradInput[i * element_count + it * warp_size_softmax] =
            (grad_reg[i][it] - output_reg[i][it] * sum[i]);
      }
    }
  }
}
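
// Elementwise product C = A * B (computed in float); used by the grad kernel
// below to build mul_grad = dout * out before calling softmax_warp_backward.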
template <typename T>
__global__ void MultiplyCUDAKernel(T* C, const T* A, const T* B, int N) {
  CUDA_KERNEL_LOOP(i, N) {
    C[i] = static_cast<T>(static_cast<float>(A[i]) * static_cast<float>(B[i]));
  }
}

template <typename T, int VPT, int WARP_PER_BLOCK>
__global__ void VecSoftmaxBackward(T* dst, const T* grad, const T* src,
                                   const int batch_size,
@@ -340,28 +425,74 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
    constexpr bool warp_softmax_available =
        std::is_same<T, float>::value ||
        std::is_same<T, platform::float16>::value;
    if (D == 1 && dim == 128 && N % warps_per_block == 0 &&
        warp_softmax_available) {
      if (std::is_same<T, float>::value) {
        VecSoftmaxBackward<
            float, 4,
            warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
                               0, ctx.cuda_device_context().stream()>>>(
            dx->data<float>(), dout->data<float>(), out->data<float>(), N, dim);
      } else if (std::is_same<T, platform::float16>::value) {
        VecSoftmaxBackward<
            platform::float16, 4,
            warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
                               0, ctx.cuda_device_context().stream()>>>(
            dx->data<platform::float16>(), dout->data<platform::float16>(),
            out->data<platform::float16>(), N, dim);
      } else {
        PADDLE_ENFORCE_EQ(
            warp_softmax_available, true,
            platform::errors::Unimplemented(
                "Warp softmax backward is only available for fp32 and fp16"));
    bool optimize = false;
    if (D == 1 && warp_softmax_available) {
      if (dim == 128 && N % warps_per_block == 0) {
        optimize = true;
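        // dim == 128: keep the vectorized VecSoftmaxBackward path (VPT = 4).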
        if (std::is_same<T, float>::value) {
          VecSoftmaxBackward<float, 4, warps_per_block><<<
              N / warps_per_block, warps_per_block * WARP_SIZE, 0,
              ctx.cuda_device_context().stream()>>>(dx->data<float>(),
                                                    dout->data<float>(),
                                                    out->data<float>(), N, dim);
        } else if (std::is_same<T, platform::float16>::value) {
          VecSoftmaxBackward<platform::float16, 4, warps_per_block><<<
              N / warps_per_block, warps_per_block * WARP_SIZE, 0,
              ctx.cuda_device_context().stream()>>>(
              dx->data<platform::float16>(), dout->data<platform::float16>(),
              out->data<platform::float16>(), N, dim);
        } else {
          PADDLE_ENFORCE_EQ(
              warp_softmax_available, true,
              platform::errors::Unimplemented(
                  "Warp softmax backward is only available for fp32 and fp16"));
        }
      } else if (dim < 40 && dim % 32 != 0) {
        optimize = true;
        Tensor mul_grad;
        int numel = N * dim;
        mul_grad.mutable_data<T>({numel}, ctx.GetPlace());
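
        // mul_grad = dout * out, computed elementwise over the flattened
        // tensor; the warp kernel then only needs a row-sum of this product.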
        auto stream = ctx.cuda_device_context().stream();
        auto& dev_ctx =
            ctx.template device_context<platform::CUDADeviceContext>();
        auto config = GetGpuLaunchConfig1D(dev_ctx, numel);

        MultiplyCUDAKernel<T><<<config.block_per_grid.x,
                                config.thread_per_block.x, 0, stream>>>(
            mul_grad.data<T>(), dout->data<T>(), out->data<T>(), numel);
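
        // Launch geometry for softmax_warp_backward: each row is padded to the
        // next power of two, a logical warp of min(next_power_of_two, 32)
        // lanes handles batches_per_warp rows, and blocks use 128 threads.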
        int log2_elements = log2_ceil(dim);
        const int next_power_of_two = 1 << log2_elements;

        int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (N + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
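
        // one case per power-of-two padded row length handled by the macro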
        switch (log2_elements) {
          LAUNCH_SOFTMAX_WARP_BACKWARD(0);  // 1
          LAUNCH_SOFTMAX_WARP_BACKWARD(1);  // 2
          LAUNCH_SOFTMAX_WARP_BACKWARD(2);  // 4
          LAUNCH_SOFTMAX_WARP_BACKWARD(3);  // 8
          LAUNCH_SOFTMAX_WARP_BACKWARD(4);  // 16
          LAUNCH_SOFTMAX_WARP_BACKWARD(5);  // 32
          LAUNCH_SOFTMAX_WARP_BACKWARD(6);  // 64
          LAUNCH_SOFTMAX_WARP_BACKWARD(7);  // 128
          LAUNCH_SOFTMAX_WARP_BACKWARD(8);  // 256
          LAUNCH_SOFTMAX_WARP_BACKWARD(9);  // 512
          default:
            break;
        }
      }
    } else {
    }
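    // fall back to the cuDNN softmax-grad path when no warp kernel was used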
    if (!optimize) {
      ScopedTensorDescriptor desc;
      std::vector<int> tensor_dims = {N, dim, D, 1};
      DataLayout layout = DataLayout::kNCHW;