|
|
@ -268,6 +268,10 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
|
|
|
|
int64_t idx_remain = idx % remain;
|
|
|
|
int64_t idx_remain = idx % remain;
|
|
|
|
// labels, loss view as [n, remain]
|
|
|
|
// labels, loss view as [n, remain]
|
|
|
|
int64_t idx_lbl = idx_n * remain + idx_remain;
|
|
|
|
int64_t idx_lbl = idx_n * remain + idx_remain;
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_,
|
|
|
|
|
|
|
|
"The value of label[%ld] expected >= 0 and < %ld,"
|
|
|
|
|
|
|
|
"but got %ld. Please check input value.",
|
|
|
|
|
|
|
|
idx_lbl, d_, labels_[idx_lbl]);
|
|
|
|
// It also would ignore labels not in range(class_num).
|
|
|
|
// It also would ignore labels not in range(class_num).
|
|
|
|
if (idx_axis != labels_[idx_lbl]) {
|
|
|
|
if (idx_axis != labels_[idx_lbl]) {
|
|
|
|
log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
|
|
|
|
log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
|
|
|
|