|
|
|
@ -36,7 +36,10 @@ static void nll_loss_1D(T* out_data, T* total_weight_data, const T* x_data,
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"label should not be out of bounds."));
|
|
|
|
|
"Label value is out of range. "
|
|
|
|
|
"Expected label value in range of [0, %d), but "
|
|
|
|
|
"received value is %d.",
|
|
|
|
|
n_classes, cur_label));
|
|
|
|
|
|
|
|
|
|
const auto cur_weight =
|
|
|
|
|
weight_data ? weight_data[cur_label] : static_cast<T>(1);
|
|
|
|
|