@@ -8,7 +8,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
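// hipCUB mirrors the CUB API, so aliasing the namespace above lets the
// kernels below keep using cub:: (e.g. cub::Sum) unchanged on ROCm builds.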
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
@@ -214,6 +220,60 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}
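// Note: the launch macros below pass loss_data as the max_data argument, so
// the zeroing above resets the loss buffer before the cross-entropy functor
// accumulates the per-row loss into it.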

#ifdef __HIPCC__  // @{ HIP Separate Kernel for RowReductionForDiffMaxSum
// Note(qili93): HIP does not support return in a kernel, so
// RowReductionForDiffMaxSum is separated into the two kernels below.
template <typename T, int BlockDim>
static __global__ void RowReductionForSum(const T* logits_data, T* max_data,
                                          T* softmax, int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  int64_t remain = d / axis_dim;
  int64_t idx_n = blockIdx.x / remain;
  int64_t idx_remain = blockIdx.x % remain;
  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int64_t end_idx = (idx_n + 1) * d;

  auto block_max = max_data[blockIdx.x];
  int64_t step = BlockDim * remain;
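  // The indices above assign each block one softmax row of length axis_dim,
  // stored with stride `remain`; the block's threads then walk that row
  // cooperatively with stride BlockDim * remain.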

  softmax[beg_idx] = logits_data[beg_idx] - block_max;
  T diff_max_sum = exp_on_device(softmax[beg_idx]);
  auto idx = beg_idx + step;
  while (idx < end_idx) {
    softmax[idx] = logits_data[idx] - block_max;
    diff_max_sum += exp_on_device(softmax[idx]);
    idx += step;
  }

  diff_max_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);
}
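// On exit, softmax[i][j] holds logits[i][j] - max_i and max_data[i] holds
// log(sum_j exp(logits[i][j] - max_i)). RowReductionForDiff below subtracts
// that log-sum once more, leaving softmax with the row-wise log-softmax:
//   logits[i][j] - max_i - log(sum_j exp(logits[i][j] - max_i))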

template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiff(const T* logits_data, T* max_data,
                                           T* softmax, int d, int axis_dim) {
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;
  int step = BlockDim * remain;

  T diff_max_sum = max_data[blockIdx.x];
  softmax[beg_idx] -= diff_max_sum;
  beg_idx += step;
  while (beg_idx < end_idx) {
    softmax[beg_idx] -= diff_max_sum;
    beg_idx += step;
  }

  __syncthreads();
  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}
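// The __syncthreads() guarantees every thread has read max_data[blockIdx.x]
// before thread 0 overwrites it; the buffer (aliased to loss_data by the
// launch macros) is then zeroed for the cross-entropy pass that follows.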
#endif  // @} End HIP Separate Kernel for RowReductionForDiffMaxSum

// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
@@ -345,6 +405,28 @@ static void HardLabelSoftmaxWithCrossEntropy(
  int64_t grid_dim = n * d / axis_dim;
  auto stream = ctx.stream();

#ifdef __HIPCC__
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)     \
  case BlockDim: {                                                            \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>),      \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data,\
                       loss_data, d, axis_dim);                               \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>),      \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data,\
                       loss_data, softmax_data, d, axis_dim);                 \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForDiff<T, BlockDim>),     \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data,\
                       loss_data, softmax_data, d, axis_dim);                 \
    platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n * d);    \
    if (ignore_idx >= 0 && ignore_idx < axis_dim) {                           \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(      \
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx));    \
    } else {                                                                  \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                   \
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx));    \
    }                                                                         \
  } break
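// hipLaunchKernelGGL(kernel, grid, block, shared_mem_bytes, stream, args...)
// is HIP's counterpart to CUDA's kernel<<<grid, block, shared, stream>>>
// syntax, and HIP_KERNEL_NAME wraps templated kernel names so they expand
// correctly inside the launch macro.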
#else
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)   \
  case BlockDim: {                                                          \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(     \
@@ -361,6 +443,7 @@ static void HardLabelSoftmaxWithCrossEntropy(
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx));  \
    }                                                                       \
  } break
#endif
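// The switch below turns the runtime block_dim into the compile-time
// BlockDim template argument that cub::BlockReduce requires.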

  switch (block_dim) {
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
@@ -383,13 +466,27 @@ static void HardLabelSoftmaxWithCrossEntropy(
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(
    const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
    int64_t n, int64_t d, int axis_dim, cudaStream_t stream) {
    int64_t n, int64_t d, int axis_dim, gpuStream_t stream) {
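  // gpuStream_t resolves to cudaStream_t under CUDA and hipStream_t under
  // ROCm, so this one signature builds against either runtime.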
  constexpr int kMaxBlockDim = 512;
  int64_t block_dim = axis_dim >= kMaxBlockDim
                          ? kMaxBlockDim
                          : (1 << static_cast<int>(std::log2(axis_dim)));
  int64_t grid_dim = n * d / axis_dim;
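  // block_dim is the largest power of two not exceeding axis_dim, capped at
  // 512 (e.g. axis_dim == 300 gives 1 << 8 == 256); grid_dim launches one
  // block per softmax row, of which there are n * d / axis_dim.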

#ifdef __HIPCC__
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                \
  case BlockDim:                                                              \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>),      \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data,\
                       loss_data, d, axis_dim);                               \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>),      \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data,\
                       loss_data, softmax_data, d, axis_dim);                 \
    hipLaunchKernelGGL(                                                       \
        HIP_KERNEL_NAME(RowReductionForSoftmaxAndCrossEntropy<T, BlockDim>),  \
        dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, labels_data,  \
        loss_data, softmax_data, d, axis_dim);                                \
    break
#else
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)             \
  case BlockDim:                                                           \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(    \
@@ -400,6 +497,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(
        T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(                   \
        logits_data, labels_data, loss_data, softmax_data, d, axis_dim);   \
    break
#endif

  switch (block_dim) {
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
@@ -536,6 +634,16 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
}  // namespace paddle

namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>);
#else
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
@@ -545,3 +653,4 @@ REGISTER_OP_CUDA_KERNEL(
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);
#endif