@@ -37,11 +37,17 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
 
 template <typename T>
 __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
-                      const int d, const int remain) {
+                      const int d, const int remain, const int64_t* labels,
+                      const int ignore_index) {
   CUDA_KERNEL_LOOP(index, num) {
     int idx_n = index / d;
     int idx_remain = index % remain;
-    logit_grad[index] *= loss_grad[idx_n * remain + idx_remain];
+    int idx_lbl = idx_n * remain + idx_remain;
+    if (labels[idx_lbl] == ignore_index) {
+      logit_grad[index] = static_cast<T>(0.);
+    } else {
+      logit_grad[index] *= loss_grad[idx_lbl];
+    }
   }
 }
 
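For context, here is a minimal standalone harness (a sketch, not part of the patch or of Paddle's build) that exercises the patched Scale kernel and shows the intended ignore_index semantics: elements whose label equals ignore_index receive a zero gradient, while all others are scaled by loss_grad as before. CUDA_KERNEL_LOOP is reproduced as the usual grid-stride loop, and kIgnoreIndex plus the tiny shapes are illustrative assumptions:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Grid-stride loop, reproduced here so the sketch is self-contained.
#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int d, const int remain, const int64_t* labels,
                      const int ignore_index) {
  CUDA_KERNEL_LOOP(index, num) {
    int idx_n = index / d;
    int idx_remain = index % remain;
    int idx_lbl = idx_n * remain + idx_remain;
    if (labels[idx_lbl] == ignore_index) {
      logit_grad[index] = static_cast<T>(0.);
    } else {
      logit_grad[index] *= loss_grad[idx_lbl];
    }
  }
}

int main() {
  // Two samples, three classes (d = 3, remain = 1); sample 1 is ignored.
  const int n = 2, d = 3, remain = 1, num = n * d;
  const int kIgnoreIndex = -100;  // illustrative ignore_index value
  float h_grad[6] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
  float h_loss[2] = {2.0f, 2.0f};
  int64_t h_labels[2] = {1, kIgnoreIndex};

  float *d_grad, *d_loss;
  int64_t* d_labels;
  cudaMalloc(&d_grad, sizeof(h_grad));
  cudaMalloc(&d_loss, sizeof(h_loss));
  cudaMalloc(&d_labels, sizeof(h_labels));
  cudaMemcpy(d_grad, h_grad, sizeof(h_grad), cudaMemcpyHostToDevice);
  cudaMemcpy(d_loss, h_loss, sizeof(h_loss), cudaMemcpyHostToDevice);
  cudaMemcpy(d_labels, h_labels, sizeof(h_labels), cudaMemcpyHostToDevice);

  Scale<float><<<1, 128>>>(d_grad, d_loss, num, d, remain, d_labels,
                           kIgnoreIndex);
  cudaMemcpy(h_grad, d_grad, sizeof(h_grad), cudaMemcpyDeviceToHost);

  // Expected: sample 0 scaled by 2.0 (0.2 0.4 0.6), sample 1 zeroed.
  for (int i = 0; i < num; ++i) printf("%f\n", h_grad[i]);
  cudaFree(d_grad);
  cudaFree(d_loss);
  cudaFree(d_labels);
  return 0;
}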
@@ -260,6 +266,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
     int idx_remain = idx % remain;
     // labels, loss view as [n, remain]
     int idx_lbl = idx_n * remain + idx_remain;
+    // It also ignores labels outside the range [0, class_num).
     if (idx_axis != labels_[idx_lbl]) {
       log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
     } else {
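The new comment points at a subtlety worth spelling out: idx_axis only ranges over [0, class_num), so a label outside that range never matches, every element of the sample takes the exp_on_device branch, and the loss branch never fires. A small host-side mirror of that control flow (a sketch; the else body is elided here, just as the hunk truncates it):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int class_num = 3;
  double log_softmax[3] = {-1.2, -0.9, -1.1};
  const int64_t label = 7;  // outside [0, class_num): effectively ignored

  for (int idx_axis = 0; idx_axis < class_num; ++idx_axis) {
    if (idx_axis != label) {
      // Mirrors exp_on_device: recover softmax from log-softmax.
      log_softmax[idx_axis] = std::exp(log_softmax[idx_axis]);
    }
    // else: the loss branch, unreachable for an out-of-range label.
  }
  for (int i = 0; i < class_num; ++i) printf("%f\n", log_softmax[i]);
  return 0;
}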
@@ -513,7 +520,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
       int num = n * d;
       grid = (num + block - 1) / block;
       Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
-                                           d, remain);
+                                           d, remain, label_data, ignore_index);
     }
   }
 };
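For reference, the launch math above is the standard ceiling division: grid = (num + block - 1) / block rounds num / block up so all n * d elements are covered by the launch. Because CUDA_KERNEL_LOOP is a grid-stride loop, correctness would not strictly require this, but it keeps one loop iteration per element in the common case. A quick arithmetic sketch (block = 512 and the shapes are assumptions; the hunk does not show them):

#include <cstdio>

int main() {
  const int n = 4, d = 1000;  // illustrative shapes
  const int block = 512;      // assumed threads per block
  const int num = n * d;      // 4000 gradient elements
  const int grid = (num + block - 1) / block;  // ceil(4000 / 512) = 8
  printf("num=%d block=%d grid=%d\n", num, block, grid);
  return 0;
}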