@@ -27,6 +27,7 @@ public:
   void Compute(const framework::KernelContext& context) const override {
     auto input = context.Input(0)->Get<framework::Tensor>();
     auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+    output->mutable_data<T>(context.GetPlace());
 
     auto logits = framework::EigenMatrix<T>::From(input);
     auto softmax = framework::EigenMatrix<T>::From(*output);
@@ -41,19 +42,21 @@ public:
     Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
     Eigen::DSizes<int, 2> one_by_class(1, num_classes);
 
-    auto shifted_logits = (logits - logits.maximum(along_class)
-                                        .eval()
-                                        .reshape(batch_by_one)
-                                        .broadcast(one_by_class));
+    auto shifted_logits = (logits -
+                           logits.maximum(along_class)
+                               .eval()
+                               .reshape(batch_by_one)
+                               .broadcast(one_by_class));
 
     softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
 
     softmax.device(*(context.GetEigenDevice<Place>())) =
-        (softmax * softmax.sum(along_class)
-                       .inverse()
-                       .eval()
-                       .reshape(batch_by_one)
-                       .broadcast(one_by_class));
+        (softmax *
+         softmax.sum(along_class)
+             .inverse()
+             .eval()
+             .reshape(batch_by_one)
+             .broadcast(one_by_class));
   }
 };
 
 }  // namespace operators
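
For reference, the two `softmax.device(...)` assignments above implement the numerically stable softmax over a `batch_size × num_classes` matrix of logits: each row is shifted by its maximum before exponentiation (softmax is invariant under a per-row constant shift, and the shift keeps `exp` from overflowing), then normalized by the row sum. The `.eval()` calls force the row-wise reductions to be materialized before being reshaped to `batch_by_one` and broadcast back to the full matrix shape:

$$
\mathrm{softmax}(x)_{ij} = \frac{\exp\left(x_{ij} - \max_k x_{ik}\right)}{\sum_k \exp\left(x_{ik} - \max_k x_{ik}\right)}
$$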
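If it helps to see the same max-shift / exponentiate / normalize pattern outside the framework, here is a minimal standalone sketch. It uses Eigen's dense Matrix/Array API rather than the `Eigen::Tensor` expressions in the kernel, and the sizes and values are made up for illustration:

```cpp
#include <Eigen/Dense>
#include <iostream>

int main() {
  // Illustrative 2 x 3 batch of logits (batch_size = 2, num_classes = 3).
  Eigen::MatrixXf logits(2, 3);
  logits << 1.0f, 2.0f, 3.0f,
            10.0f, 10.0f, 10.0f;

  // Shift each row by its maximum so exp() cannot overflow;
  // softmax is invariant to a per-row constant shift.
  Eigen::ArrayXXf shifted =
      (logits.colwise() - logits.rowwise().maxCoeff()).array();

  // Exponentiate, then normalize each row by its sum. Materializing
  // row_sums first avoids reading softmax while it is being written.
  Eigen::ArrayXXf softmax = shifted.exp();
  Eigen::ArrayXf row_sums = softmax.rowwise().sum();
  softmax.colwise() /= row_sums;

  std::cout << softmax << "\n";  // each row now sums to 1
  return 0;
}
```

The explicit `row_sums` temporary plays the same role as the `.eval()` in the kernel: both force the reduction to be computed once before it is broadcast across the row.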