|
|
|
@ -253,12 +253,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
|
|
|
|
|
public:
|
|
|
|
|
HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
|
|
|
|
|
T* log_softmax, int64_t d,
|
|
|
|
|
int axis_dim)
|
|
|
|
|
int axis_dim, int ignore_idx)
|
|
|
|
|
: labels_(labels),
|
|
|
|
|
loss_(loss),
|
|
|
|
|
log_softmax_(log_softmax),
|
|
|
|
|
d_(d),
|
|
|
|
|
axis_dim_(axis_dim) {}
|
|
|
|
|
axis_dim_(axis_dim),
|
|
|
|
|
ignore_idx_(ignore_idx) {}
|
|
|
|
|
|
|
|
|
|
__device__ void operator()(int64_t idx) const {
|
|
|
|
|
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
|
|
|
|
@ -268,10 +269,11 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
|
|
|
|
|
int64_t idx_remain = idx % remain;
|
|
|
|
|
// labels, loss view as [n, remain]
|
|
|
|
|
int64_t idx_lbl = idx_n * remain + idx_remain;
|
|
|
|
|
PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_,
|
|
|
|
|
"The value of label[%ld] expected >= 0 and < %ld,"
|
|
|
|
|
PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_ ||
|
|
|
|
|
labels_[idx_lbl] == ignore_idx_,
|
|
|
|
|
"The value of label[%ld] expected >= 0 and < %ld, or == %d,"
|
|
|
|
|
"but got %ld. Please check input value.",
|
|
|
|
|
idx_lbl, d_, labels_[idx_lbl]);
|
|
|
|
|
idx_lbl, d_, ignore_idx_, labels_[idx_lbl]);
|
|
|
|
|
// It also would ignore labels not in range(class_num).
|
|
|
|
|
if (idx_axis != labels_[idx_lbl]) {
|
|
|
|
|
log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
|
|
|
|
@ -288,6 +290,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
|
|
|
|
|
T* log_softmax_;
|
|
|
|
|
int64_t d_;
|
|
|
|
|
int axis_dim_;
|
|
|
|
|
int ignore_idx_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -355,7 +358,7 @@ static void HardLabelSoftmaxWithCrossEntropy(
|
|
|
|
|
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
|
|
|
|
|
} else { \
|
|
|
|
|
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
|
|
|
|
|
labels_data, loss_data, softmax_data, d, axis_dim)); \
|
|
|
|
|
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
|
|
|
|
|
} \
|
|
|
|
|
} break
|
|
|
|
|
|
|
|
|
|