@@ -17,6 +17,7 @@ limitations under the License. */
 #include <cub/cub.cuh>
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
+#include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
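
The only addition above is the `for_range.h` include: it provides `platform::ForRange`, which invokes a functor once per index on the device, and the new hard-label code further down depends on it. As a rough host-side sketch of that contract (the functor here is illustrative, not from this patch):

    #include <cstdio>

    // Sketch of the ForRange contract: invoke f(i) for every i in [0, limit).
    template <typename Functor>
    void ForRangeSketch(int limit, Functor f) {
      for (int i = 0; i < limit; ++i) f(i);
    }

    struct SquareFunctor {  // illustrative device-style functor
      int* data;
      void operator()(int i) const { data[i] *= data[i]; }
    };

    int main() {
      int v[3] = {1, 2, 3};
      ForRangeSketch(3, SquareFunctor{v});
      std::printf("%d %d %d\n", v[0], v[1], v[2]);  // prints: 1 4 9
    }
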
@@ -117,8 +118,8 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
 // Make sure that BlockDim <= feature_size
 // This kernel is used to calculate the max element of each row
 template <typename T, int BlockDim>
-__global__ void RowReductionForMax(const T* logits_data, T* max_data,
-                                   int feature_size) {
+static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
+                                          int feature_size) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
 
   auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
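
This hunk only marks the kernel `static`, giving the template `__global__` function internal linkage so an identically named kernel in another translation unit cannot clash at link time. The reduction itself is unchanged; for readers without the full file, a simplified standalone equivalent of the row-max pattern, using a plain shared-memory tree reduction where the real kernel uses `cub::BlockReduce` (BlockDim assumed a power of two, one block per row):

    #include <cfloat>

    template <int BlockDim>
    static __global__ void RowMaxSketch(const float* logits, float* max_out,
                                        int feature_size) {
      __shared__ float smem[BlockDim];
      const float* row = logits + blockIdx.x * feature_size;
      float m = -FLT_MAX;
      for (int i = threadIdx.x; i < feature_size; i += BlockDim)
        m = fmaxf(m, row[i]);  // BlockDim-strided scan of one row
      smem[threadIdx.x] = m;
      __syncthreads();
      for (int s = BlockDim / 2; s > 0; s >>= 1) {  // tree reduction
        if (threadIdx.x < s)
          smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
        __syncthreads();
      }
      if (threadIdx.x == 0) max_out[blockIdx.x] = smem[0];
    }
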
@@ -141,9 +142,10 @@ __global__ void RowReductionForMax(const T* logits_data, T* max_data,
 }
 
 // Make sure that BlockDim <= feature_size
-template <typename T, int BlockDim>
-__global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
-                                          T* softmax, int feature_size) {
+template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
+static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
+                                                 T* max_data, T* softmax,
+                                                 int feature_size) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
 
   auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
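
Note that the new `CalculateLogSoftmax` template parameter defaults to `false`, so the existing soft-label caller keeps compiling and behaving exactly as before, while the hard-label launcher added below opts in with an explicit `true`. A minimal illustration of that compatibility pattern (names here are made up):

    #include <cstdio>

    template <typename T, bool ExtraPass = false>
    void Reduce() {
      std::printf(ExtraPass ? "base pass + log-softmax pass\n"
                            : "base pass only\n");
    }

    int main() {
      Reduce<float>();        // old call sites: flag defaults to false
      Reduce<float, true>();  // the new hard-label launch opts in
    }
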
@@ -153,24 +155,34 @@ __global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
   softmax[beg_idx] = logits_data[beg_idx] - block_max;
   T diff_max_sum = real_exp(softmax[beg_idx]);
-  beg_idx += BlockDim;
-  while (beg_idx < end_idx) {
-    softmax[beg_idx] = logits_data[beg_idx] - block_max;
-    diff_max_sum += real_exp(softmax[beg_idx]);
-    beg_idx += BlockDim;
+  auto idx = beg_idx + BlockDim;
+  while (idx < end_idx) {
+    softmax[idx] = logits_data[idx] - block_max;
+    diff_max_sum += real_exp(softmax[idx]);
+    idx += BlockDim;
   }
 
   diff_max_sum =
       BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
   if (threadIdx.x == 0) max_data[blockIdx.x] = real_log(diff_max_sum);
+
+  if (!CalculateLogSoftmax) return;
+  __syncthreads();
+  diff_max_sum = max_data[blockIdx.x];
+  softmax[beg_idx] -= diff_max_sum;
+  beg_idx += BlockDim;
+  while (beg_idx < end_idx) {
+    softmax[beg_idx] -= diff_max_sum;
+    beg_idx += BlockDim;
+  }
+  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
 }
 
 // Make sure that BlockDim <= feature_size
 template <typename T, int BlockDim>
-__global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
-                                                      const T* labels_data,
-                                                      T* loss_data, T* softmax,
-                                                      int feature_size) {
+static __global__ void RowReductionForSoftmaxAndCrossEntropy(
+    const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
+    int feature_size) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
 
   auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
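
The block added under `CalculateLogSoftmax` finishes the log-softmax in place: after the first loop, softmax[i] holds x_i - max(x) and thread 0 has stored log(sum_j exp(x_j - max(x))) into max_data[row]; the second loop subtracts that value, which is the numerically stable identity

    log_softmax(x)_i = (x_i - max(x)) - log(sum_j exp(x_j - max(x)))

Worked on one row x = (1, 2, 3): max = 3, shifted values (-2, -1, 0), sum of exponentials ≈ 1.5032, log ≈ 0.4076, giving log-softmax ≈ (-2.408, -1.408, -0.408). The final `max_data[blockIdx.x] = 0` matters because the launcher below aliases max_data to loss_data: rows whose loss the follow-up functor never writes (the ignore_idx case) are left with a loss of exactly 0.
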
@@ -194,11 +206,134 @@ __global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
 }
 
 template <typename T>
-__global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int batch_size) {
+struct HardLabelSoftmaxWithCrossEntropyFunctor {
+ public:
+  HardLabelSoftmaxWithCrossEntropyFunctor(const T* logits,
+                                          const int64_t* labels, T* loss,
+                                          T* log_softmax, int feature_size)
+      : logits_(logits),
+        labels_(labels),
+        loss_(loss),
+        log_softmax_(log_softmax),
+        feature_size_(feature_size) {}
+
+  __device__ void operator()(int idx) const {
+    auto row_idx = idx / feature_size_;
+    auto col_idx = idx % feature_size_;
+    if (col_idx != labels_[row_idx]) {
+      log_softmax_[idx] = real_exp(log_softmax_[idx]);
+    } else {
+      auto softmax = log_softmax_[idx];
+      log_softmax_[idx] = real_exp(softmax);
+      loss_[row_idx] = -softmax;
+    }
+  }
+
+ private:
+  const T* logits_;
+  const int64_t* labels_;
+  T* loss_;
+  T* log_softmax_;
+  int feature_size_;
+};
+
+template <typename T>
+struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
+ public:
+  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const T* logits,
+                                                       const int64_t* labels,
+                                                       T* loss, T* log_softmax,
+                                                       int feature_size,
+                                                       int ignore_idx)
+      : logits_(logits),
+        labels_(labels),
+        loss_(loss),
+        log_softmax_(log_softmax),
+        feature_size_(feature_size),
+        ignore_idx_(ignore_idx) {}
+
+  __device__ void operator()(int idx) const {
+    auto row_idx = idx / feature_size_;
+    auto col_idx = idx % feature_size_;
+    if (col_idx != labels_[row_idx] || col_idx == ignore_idx_) {
+      log_softmax_[idx] = real_exp(log_softmax_[idx]);
+    } else {
+      auto softmax = log_softmax_[idx];
+      log_softmax_[idx] = real_exp(softmax);
+      loss_[row_idx] = -softmax;
+    }
+  }
+
+ private:
+  const T* logits_;
+  const int64_t* labels_;
+  T* loss_;
+  T* log_softmax_;
+  int feature_size_;
+  int ignore_idx_;
+};
+
+template <typename T>
+static __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out,
+                                                           int batch_size) {
   auto idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < batch_size) out[idx] = static_cast<T>(1);
 }
 
+template <typename T>
+static void HardLabelSoftmaxWithCrossEntropy(
+    const platform::CUDADeviceContext& ctx, const T* logits_data,
+    const int64_t* labels_data, T* loss_data, T* softmax_data, int batch_size,
+    int feature_size, int ignore_idx) {
+  constexpr int kMaxBlockDim = 512;
+  int block_dim = feature_size >= kMaxBlockDim
+                      ? kMaxBlockDim
+                      : (1 << static_cast<int>(std::log2(feature_size)));
+  auto stream = ctx.stream();
+
+#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)     \
+  case BlockDim: {                                                            \
+    RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(     \
+        logits_data, loss_data, feature_size);                                \
+    RowReductionForDiffMaxSum<T, BlockDim,                                    \
+                              true><<<batch_size, BlockDim, 0, stream>>>(     \
+        logits_data, loss_data, softmax_data, feature_size);                  \
+    platform::ForRange<platform::CUDADeviceContext> for_range(                \
+        ctx, batch_size* feature_size);                                       \
+    if (ignore_idx >= 0 && ignore_idx < feature_size) {                       \
+      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(      \
+          logits_data, labels_data, loss_data, softmax_data, feature_size,    \
+          ignore_idx));                                                       \
+    } else {                                                                  \
+      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                   \
+          logits_data, labels_data, loss_data, softmax_data, feature_size));  \
+    }                                                                         \
+  } break
+
+  switch (block_dim) {
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
+    case 1:
+      SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) /
+                                                kMaxBlockDim,
+                                            kMaxBlockDim, 0, stream>>>(
+          softmax_data, batch_size);
+      cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream);
+      break;
+    default:
+      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
+      break;
+  }
+#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
+}
+
 template <typename T>
 static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                                const T* labels_data,
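
`HardLabelSoftmaxWithCrossEntropy` rounds feature_size down to a power of two so that a matching kernel instantiation exists in the switch, capping at 512. A host-only sketch of that computation (the wrapper name is made up):

    #include <cmath>
    #include <cstdio>

    // Mirror of the block_dim expression above; kMaxBlockDim is 512 in the patch.
    int PickBlockDim(int feature_size, int k_max = 512) {
      return feature_size >= k_max
                 ? k_max
                 : (1 << static_cast<int>(std::log2(feature_size)));
    }

    int main() {
      // 100 -> 64, 512 -> 512, 1 -> 1 (which falls into the `case 1:` branch,
      // where the softmax is trivially all ones and the loss is zeroed).
      std::printf("%d %d %d\n", PickBlockDim(100), PickBlockDim(512),
                  PickBlockDim(1));
    }

After the two row-reduction launches, softmax_data holds the log-softmax and loss_data holds zeros; the `ForRange` functor then converts the buffer to softmax via real_exp while recording loss = -log_softmax at each row's label column. With the ignore-index variant, ignored rows are skipped and keep the zero loss.
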
@@ -237,7 +372,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                                 kMaxBlockDim,
                                             kMaxBlockDim, 0, stream>>>(
           softmax_data, batch_size);
-      cudaMemsetAsync(loss_data, 0, batch_size, stream);
+      cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream);
       break;
     default:
       PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
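
The only functional change in this hunk is the byte count: the third argument of `cudaMemsetAsync` is a size in bytes, so the old call cleared just batch_size bytes of a batch_size * sizeof(T) buffer (one quarter of it for T = float). For reference, the runtime signature, plus an element-safe wrapper of the kind one might use (the wrapper is illustrative, not part of the patch):

    #include <cuda_runtime.h>

    // cudaError_t cudaMemsetAsync(void* devPtr, int value,
    //                             size_t count /* bytes */,
    //                             cudaStream_t stream = 0);
    template <typename T>
    cudaError_t ZeroAsync(T* dev_ptr, size_t n_elems, cudaStream_t stream) {
      return cudaMemsetAsync(dev_ptr, 0, n_elems * sizeof(T), stream);
    }
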
@@ -272,11 +407,21 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
           logits_data, labels_data, softmax_data, loss_data, batch_size,
           feature_size, context.cuda_device_context().stream());
     } else {
-      math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
-                                     softmax);
-      math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
-          context.cuda_device_context(), loss, softmax, labels, false,
-          ignore_index);
+      if (!context.Attr<bool>("numeric_stable_mode")) {
+        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
+                                       softmax);
+        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
+            context.cuda_device_context(), loss, softmax, labels, false,
+            ignore_index);
+      } else {
+        int batch_size = logits->dims()[0];
+        int feature_size = logits->dims()[1];
+        auto* logits_data = logits->data<T>();
+        auto* labels_data = labels->data<int64_t>();
+        HardLabelSoftmaxWithCrossEntropy<T>(
+            context.cuda_device_context(), logits_data, labels_data, loss_data,
+            softmax_data, batch_size, feature_size, ignore_index);
+      }
     }
   }
 };
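
With this last hunk the hard-label path also honors the op's numeric_stable_mode attribute. Judging from the context lines (the call above the `} else {` appears to be the soft-label fused path, which is not shown in full), dispatch now works roughly as in this host-side mirror; the soft_label condition is inferred, not confirmed by the hunk:

    #include <cstdio>

    // Hypothetical mirror of the Compute() dispatch after this patch.
    const char* Dispatch(bool soft_label, bool numeric_stable_mode) {
      if (soft_label) return "SoftmaxWithCrossEntropyFusedKernel";
      return numeric_stable_mode
                 ? "HardLabelSoftmaxWithCrossEntropy (new fused path)"
                 : "cuDNN softmax + CrossEntropyFunctor (old behavior)";
    }

    int main() { std::printf("%s\n", Dispatch(false, true)); }
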