@@ -24,25 +24,78 @@ namespace operators {

 using Tensor = framework::Tensor;

 template <typename T>
-__global__ void CrossEntropyKernel(T* out, const T* softmax_out,
-                                   const int* label, const int batch_size,
-                                   const int class_num) {
+__global__ void CrossEntropy(T* out, const T* softmax_out, const int* labels,
+                             const int batch_size, const int class_num) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < batch_size) {
-    PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
-    out[i] = -tolerable_value(std::log(softmax_out[i * class_num + label[i]]));
+    PADDLE_ASSERT(labels[i] >= 0 && labels[i] < class_num);
+    out[i] =
+        -TolerableValue<T>()(std::log(softmax_out[i * class_num + labels[i]]));
   }
 }

 template <typename T>
-__global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
-                                                  const int* label,
-                                                  const int batch_size,
-                                                  const int class_num) {
-  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i < batch_size) {
-    PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
-    softmax_out[i * class_num + label[i]] -= 1.;
+__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
+                                 const int* labels, const int batch_size,
+                                 const int class_num) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int sample_idx = tid / class_num;
+
+  if (tid < batch_size * class_num) out_grad[tid] *= in_grad[sample_idx];
+  __syncthreads();
+
+  if (tid < batch_size) {
+    PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
+    out_grad[tid * class_num + labels[tid]] -= 1.;
   }
 }

+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+  val += __shfl_down(val, 16);
+  val += __shfl_down(val, 8);
+  val += __shfl_down(val, 4);
+  val += __shfl_down(val, 2);
+  val += __shfl_down(val, 1);
+  return val;
+}
+
+template <typename T>
+__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
+                                       const int class_num) {
+  int tid = threadIdx.x;
+  extern __shared__ T d_sum[];
+  d_sum[tid] = 0;
+
+  int cur_idx = tid;
+  int next_idx = blockIdx.x * class_num + tid;
+  while (cur_idx < class_num) {
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockDim.x;
+    cur_idx += blockDim.x;
+  }
+  __syncthreads();
+
+  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
+  }
+
+  T val = d_sum[tid];
+  val = sum_single_warp<T>(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
+}
+
+template <typename T>
+__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
+                                               const T* loss_grad,
+                                               const T* labels,
+                                               const int batch_size,
+                                               const int class_num) {
+  int ids = blockIdx.x * blockDim.x + threadIdx.x;
+  if (ids < batch_size * class_num) {
+    int row_ids = ids / class_num;
+    logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids];
+  }
+}
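Note: the kernels route std::log through TolerableValue<T>() so that a zero
probability yields a large finite loss instead of propagating -inf/nan. The
functor is defined outside this diff (in the cross_entropy op headers); a
minimal sketch of the idea, assuming only that it clamps infinities, would
look roughly like:

    // Sketch only -- not the definition shipped with this diff.
    // Clamp +/-inf (e.g. from log(0)) to a large finite surrogate.
    template <typename T>
    struct TolerableValue {
      __host__ __device__ T operator()(const T& x) const {
        const T kApproInf = 1e20;               // finite stand-in for infinity
        if (x == INFINITY) return kApproInf;    // overflow
        if (x == -INFINITY) return -kApproInf;  // log(0)
        return x;
      }
    };

SoftCrossEntropyKernel reduces one sample per block: each thread strides
across the row accumulating label[j] * log(X[j]) into shared memory, the
stride loop folds the block down to a single warp, and sum_single_warp
finishes the sum with shuffle intrinsics. __shfl_down is the pre-CUDA-9
intrinsic; on CUDA 9 and later it is deprecated in favor of the _sync
variants, so an equivalent reduction (hypothetical here, not part of the
diff) would be:

    // Hypothetical CUDA 9+ counterpart of sum_single_warp; the diff itself
    // targets the older, unsynchronized __shfl_down intrinsic.
    template <typename T>
    __device__ __forceinline__ T sum_single_warp_sync(T val) {
      for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_down_sync(0xffffffffu, val, offset);  // full-warp mask
      return val;
    }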
@@ -52,27 +105,36 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                    "This kernel only runs on GPU device.");
+    T* loss_data =
+        context.Output<Tensor>("Loss")->mutable_data<T>(context.GetPlace());

-    // Calculate ths softmax outputs.
     const Tensor* logits = context.Input<Tensor>("Logits");
     Tensor* softmax = context.Output<Tensor>("Softmax");
-    softmax->mutable_data<T>(context.GetPlace());
-    math::SoftmaxFunctor<platform::GPUPlace, T>()(logits, softmax, context);
-    T* softmax_out = softmax->data<T>();
-
-    // Calculate the cross entropy loss based on hard labels.
-    const int* label_data = context.Input<Tensor>("Label")->data<int>();
-    Tensor* loss = context.Output<Tensor>("Loss");
-    loss->mutable_data<T>(context.GetPlace());
-    T* loss_data = loss->data<T>();
+    T* softmax_out = softmax->mutable_data<T>(context.GetPlace());
+    math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);

     const int batch_size = logits->dims()[0];
     const int class_num = logits->dims()[1];
     int block = 512;
     int grid = (batch_size + block - 1) / block;

-    CrossEntropyKernel<T><<<grid, block>>>(loss_data, softmax_out, label_data,
-                                           batch_size, class_num);
+    if (context.Attr<bool>("softLabel")) {
+      const T* label_data = context.Input<Tensor>("Label")->data<T>();
+      block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+      SoftCrossEntropyKernel<
+          T><<<batch_size, block, block * sizeof(T),
+               reinterpret_cast<const platform::CUDADeviceContext&>(
+                   context.device_context())
+                   .stream()>>>(loss_data, softmax_out, label_data, class_num);
+    } else {
+      const int* label_data = context.Input<Tensor>("Label")->data<int>();
+      CrossEntropy<T><<<grid, block, 0,
+                        reinterpret_cast<const platform::CUDADeviceContext&>(
+                            context.device_context())
+                            .stream()>>>(loss_data, softmax_out, label_data,
+                                         batch_size, class_num);
+    }
   }
 };
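In the forward pass, the hard-label branch keeps the one-thread-per-sample
launch (grid covers batch_size), while the soft-label branch launches one
block per sample, sized to a power of two no larger than 512, with
block * sizeof(T) bytes of dynamic shared memory backing d_sum. The
power-of-two block is what lets the stride >>= 1 tree reduction in
SoftCrossEntropyKernel halve cleanly. A minimal host-side sketch of the
block-size rounding (RoundDownPow2 is a name invented here for
illustration; the diff inlines the expression):

    #include <cassert>
    #include <cmath>

    // Largest power of two <= class_num, capped at 512 threads per block.
    // Mirrors: block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
    static int RoundDownPow2(int class_num) {
      return class_num > 512
                 ? 512
                 : static_cast<int>(
                       std::pow(2, static_cast<int>(std::log2(class_num))));
    }

    int main() {
      assert(RoundDownPow2(1000) == 512);  // capped at 512
      assert(RoundDownPow2(100) == 64);    // 2^6 is the largest power of two <= 100
      assert(RoundDownPow2(512) == 512);   // already a power of two
      return 0;
    }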
@@ -82,7 +144,9 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                    "This kernel only runs on GPU device.");
+
+    const Tensor* labels = context.Input<Tensor>("Label");
     const T* loss_grad_data =
         context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
@@ -90,14 +154,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {

     const int batch_size = logit_grad->dims()[0];
     const int class_num = logit_grad->dims()[1];

-    const int* label_data = context.Input<Tensor>("Label")->data<int>();
-
-    const int block = 512;
-    const int grid = (batch_size + block - 1) / block;
-
-    CrossEntropyWithSoftmaxGradKernel<T><<<grid, block>>>(
-        logit_grad_data, label_data, batch_size, class_num);
+    int block = 512;
+    int grid = (batch_size * class_num + block - 1) / block;
+
+    if (context.Attr<bool>("softLabel")) {
+      const T* label_data = labels->data<T>();
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              context.device_context())
+                              .stream()>>>(logit_grad_data, loss_grad_data,
+                                           label_data, batch_size, class_num);
+    } else {
+      const int* label_data = labels->data<int>();
+      CrossEntropyGrad<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              context.device_context())
+                              .stream()>>>(logit_grad_data, loss_grad_data,
+                                           label_data, batch_size, class_num);
+    }
   }
 };
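Both backward kernels rely on the standard softmax-with-cross-entropy
gradient, which is why logit_grad can alias the saved Softmax output
(ShareDataWith above) and be rewritten in place. For a row of probabilities
p = softmax(x) and upstream gradient g on the per-sample loss, the identity
is the standard result (stated here for reference, not quoted from the
diff):

    % hard label y:   L = -\log p_y
    \frac{\partial L}{\partial x_j} = (p_j - \mathbf{1}[j = y])\, g
    % soft labels t (rows summing to 1):   L = -\sum_j t_j \log p_j
    \frac{\partial L}{\partial x_j} = (p_j - t_j)\, g

CrossEntropyGrad scales by in_grad first and then subtracts an unscaled 1.
at the label column, and SoftCrossEntropyGradientKernel likewise subtracts
the unscaled labels[ids]; both coincide with the identity above exactly
when the upstream gradient g is 1, the usual case when the loss is the
network's final output.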