|
|
|
@ -23,21 +23,21 @@ using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad,
|
|
|
|
|
const int64_t* labels, const int batch_size,
|
|
|
|
|
const int class_num) {
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int sample_idx = tid / class_num;
|
|
|
|
|
|
|
|
|
|
if (tid < batch_size) {
|
|
|
|
|
PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
|
|
|
|
|
logit_grad[tid * class_num + labels[tid]] -= static_cast<T>(1.);
|
|
|
|
|
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
|
|
|
|
|
const int batch_size, const int class_num) {
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
int idx = i * class_num + labels[i];
|
|
|
|
|
logit_grad[idx] -= static_cast<T>(1.);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
if (tid < batch_size * class_num) {
|
|
|
|
|
logit_grad[tid] *= loss_grad[sample_idx];
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
|
|
|
|
|
const int class_num) {
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
logit_grad[i] *= loss_grad[i / class_num];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -94,22 +94,22 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int batch_size = logit_grad->dims()[0];
|
|
|
|
|
const int class_num = logit_grad->dims()[1];
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (batch_size * class_num + block - 1) / block;
|
|
|
|
|
auto stream = context.cuda_device_context().stream();
|
|
|
|
|
|
|
|
|
|
if (context.Attr<bool>("soft_label")) {
|
|
|
|
|
int grid = (batch_size * class_num + block - 1) / block;
|
|
|
|
|
const T* label_data = labels->data<T>();
|
|
|
|
|
SoftCrossEntropyGradientKernel<
|
|
|
|
|
T><<<grid, block, 0,
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>()
|
|
|
|
|
.stream()>>>(logit_grad_data, loss_grad_data, label_data,
|
|
|
|
|
batch_size, class_num);
|
|
|
|
|
SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
|
|
|
|
|
logit_grad_data, loss_grad_data, label_data, batch_size, class_num);
|
|
|
|
|
} else {
|
|
|
|
|
int grid = (batch_size + block - 1) / block;
|
|
|
|
|
const int64_t* label_data = labels->data<int64_t>();
|
|
|
|
|
CrossEntropyGrad<
|
|
|
|
|
T><<<grid, block, 0,
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>()
|
|
|
|
|
.stream()>>>(logit_grad_data, loss_grad_data, label_data,
|
|
|
|
|
batch_size, class_num);
|
|
|
|
|
CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
|
|
|
|
|
logit_grad_data, label_data, batch_size, class_num);
|
|
|
|
|
int num = batch_size * class_num;
|
|
|
|
|
grid = (num + block - 1) / block;
|
|
|
|
|
Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
|
|
|
|
|
class_num);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|