|
|
|
@ -26,6 +26,7 @@ public:
|
|
|
|
|
void Compute(const framework::KernelContext& context) const override {
|
|
|
|
|
auto input = context.Input(0)->Get<framework::Tensor>();
|
|
|
|
|
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto logits = input.matrix<T>();
|
|
|
|
|
auto softmax = output->matrix<T>();
|
|
|
|
@ -40,19 +41,21 @@ public:
|
|
|
|
|
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
|
|
|
|
|
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
|
|
|
|
|
|
|
|
|
|
auto shifted_logits = (logits - logits.maximum(along_class)
|
|
|
|
|
.eval()
|
|
|
|
|
.reshape(batch_by_one)
|
|
|
|
|
.broadcast(one_by_class));
|
|
|
|
|
auto shifted_logits = (logits -
|
|
|
|
|
logits.maximum(along_class)
|
|
|
|
|
.eval()
|
|
|
|
|
.reshape(batch_by_one)
|
|
|
|
|
.broadcast(one_by_class));
|
|
|
|
|
|
|
|
|
|
softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
|
|
|
|
|
|
|
|
|
|
softmax.device(*(context.GetEigenDevice<Place>())) =
|
|
|
|
|
(softmax * softmax.sum(along_class)
|
|
|
|
|
.inverse()
|
|
|
|
|
.eval()
|
|
|
|
|
.reshape(batch_by_one)
|
|
|
|
|
.broadcast(one_by_class));
|
|
|
|
|
(softmax *
|
|
|
|
|
softmax.sum(along_class)
|
|
|
|
|
.inverse()
|
|
|
|
|
.eval()
|
|
|
|
|
.reshape(batch_by_one)
|
|
|
|
|
.broadcast(one_by_class));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|