|
|
|
@ -18,7 +18,24 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
// Floor passed to std::max() alongside the selected X entry before it is used
// as the argument of std::log or as a divisor.
static constexpr float kCrossEntropyLogThreshold = 1e-20;
|
|
|
|
|
// Replaces an infinite floating-point value with a large finite one.
//
// Returns 1e20 when x is +infinity, -1e20 when x is -infinity, and x itself
// for every other input. NaN compares unequal to both infinities, so it is
// returned unchanged.
//
// T must be a floating-point type (float, double or long double); any other
// type is rejected at compile time.
template <typename T>
T tolerable_value(T x) {
  static_assert(std::is_floating_point<T>::value,
                "tolerable_value works only on float, "
                "double and long double.");

  // Finite stand-in for infinity; 1e20 is within range of every
  // floating-point type accepted above.
  const T kApproInf = static_cast<T>(1e20);

  if (x == INFINITY) {
    return kApproInf;
  }

  if (x == -INFINITY) {
    return -kApproInf;
  }

  return x;
}
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class OnehotCrossEntropyOpKernel : public OpKernel {
|
|
|
|
@ -36,10 +53,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
|
|
|
|
|
int batch_size = X->dims()[0];
|
|
|
|
|
int class_num = X->dims()[1];
|
|
|
|
|
|
|
|
|
|
// Y[i] = -log(X[i][label[i]])
|
|
|
|
|
for (int i = 0; i < batch_size; ++i) {
|
|
|
|
|
Ydata[i] = -std::log(std::max(Xdata[i * class_num + label_data[i]],
|
|
|
|
|
kCrossEntropyLogThreshold));
|
|
|
|
|
int index = i * class_num + label_data[i];
|
|
|
|
|
Ydata[i] = -tolerable_value(std::log(Xdata[index]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -62,9 +78,8 @@ class OnehotCrossEntropyGradientOpKernel : public OpKernel {
|
|
|
|
|
const int class_num = X->dims()[1];
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < batch_size; ++i) {
|
|
|
|
|
dXdata[i * class_num + label_data[i]] =
|
|
|
|
|
-dYdata[i] / std::max(Xdata[i * class_num + label_data[i]],
|
|
|
|
|
kCrossEntropyLogThreshold);
|
|
|
|
|
int index = i * class_num + label_data[i];
|
|
|
|
|
dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|