Check label range in cross entropy calculation. (#10954)

release/0.13.0
qingqing01 7 years ago committed by GitHub
parent 0c44efb855
commit 3ba75d4a69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -46,7 +46,10 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
const int64_t* label_data = labels->data<int64_t>();
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
int lbl = label_data[i];
PADDLE_ENFORCE_GE(lbl, 0);
PADDLE_ENFORCE_LT(lbl, class_num);
int index = i * class_num + lbl;
loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
}
}

Loading…
Cancel
Save