|
|
|
@ -45,19 +45,21 @@ class SoftmaxKernel : public framework::OpKernel {
|
|
|
|
|
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
|
|
|
|
|
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
|
|
|
|
|
|
|
|
|
|
auto shifted_logits = (logits - logits.maximum(along_class)
|
|
|
|
|
.eval()
|
|
|
|
|
.reshape(batch_by_one)
|
|
|
|
|
.broadcast(one_by_class));
|
|
|
|
|
auto shifted_logits = (logits -
|
|
|
|
|
logits.maximum(along_class)
|
|
|
|
|
.eval()
|
|
|
|
|
.reshape(batch_by_one)
|
|
|
|
|
.broadcast(one_by_class));
|
|
|
|
|
|
|
|
|
|
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
|
|
|
|
|
|
|
|
|
|
softmax.device(context.GetEigenDevice<Place>()) =
|
|
|
|
|
(softmax * softmax.sum(along_class)
|
|
|
|
|
.inverse()
|
|
|
|
|
.eval()
|
|
|
|
|
.reshape(batch_by_one)
|
|
|
|
|
.broadcast(one_by_class));
|
|
|
|
|
(softmax *
|
|
|
|
|
softmax.sum(along_class)
|
|
|
|
|
.inverse()
|
|
|
|
|
.eval()
|
|
|
|
|
.reshape(batch_by_one)
|
|
|
|
|
.broadcast(one_by_class));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|