|
|
|
@ -25,6 +25,14 @@ template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ValueClip {
|
|
|
|
|
HOSTDEVICE T operator()(const T& x) const {
|
|
|
|
|
const T kThreshold = -64.;
|
|
|
|
|
return x < kThreshold ? kThreshold : x;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class SoftmaxFunctor {
|
|
|
|
|
public:
|
|
|
|
@ -47,7 +55,8 @@ class SoftmaxFunctor {
|
|
|
|
|
logits.maximum(along_class)
|
|
|
|
|
.eval()
|
|
|
|
|
.reshape(batch_by_one)
|
|
|
|
|
.broadcast(one_by_class));
|
|
|
|
|
.broadcast(one_by_class))
|
|
|
|
|
.unaryExpr(ValueClip<T>());
|
|
|
|
|
|
|
|
|
|
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
|
|
|
|
|
softmax.device(context.GetEigenDevice<Place>()) =
|
|
|
|
|