From 99c78b772a88c4534a7db7f242ce14be16f1caec Mon Sep 17 00:00:00 2001
From: Kaipeng Deng
Date: Mon, 16 Sep 2019 10:24:34 +0800
Subject: [PATCH] fix softmax axis!=-1. test=develop (#19800)

---
 paddle/fluid/operators/math/softmax_impl.h | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index 4fb03cdce0..d568b186a0 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -41,6 +41,7 @@ void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
                   const framework::Tensor* X, framework::Tensor* Y) {
   constexpr int kBatchDim = 0;
   constexpr int kClassDim = 1;
+  constexpr int kAxisDim = 1;
 
   auto logits = EigenMatrix<T>::From(*X);
   auto softmax = EigenMatrix<T>::From(*Y);
@@ -49,26 +50,28 @@ void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
   const int num_classes = logits.dimension(kClassDim);
   const int num_remain = num_classes / axis_dim;
 
-  Eigen::DSizes<int, 1> along_class(kClassDim);
-  Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-  Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+  Eigen::DSizes<int, 1> along_axis(kAxisDim);
+  Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
+  Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
+  Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
   Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-  Eigen::DSizes<int, 2> one_axis(1, axis_dim);
 
-  auto shifted_logits = (logits -
-                         logits.maximum(along_class)
+  auto logits_reshape = logits.reshape(batch_axis_remain);
+  auto shifted_logits = (logits_reshape -
+                         logits_reshape.maximum(along_axis)
                              .eval()
-                             .reshape(batch_by_one)
-                             .broadcast(one_by_class))
+                             .reshape(batch_one_remain)
+                             .broadcast(one_axis_one))
                             .unaryExpr(ValueClip<T>());
 
-  softmax.device(*context.eigen_device()) = shifted_logits.exp();
-  softmax.device(*context.eigen_device()) = (softmax *
-                                             softmax.reshape(batch_axis_remain)
-                                                 .sum(along_class)
+  auto exp = shifted_logits.exp();
+  softmax.device(*context.eigen_device()) = (exp *
+                                             exp.sum(along_axis)
                                                  .inverse()
                                                  .eval()
-                                                 .broadcast(one_axis));
+                                                 .reshape(batch_one_remain)
+                                                 .broadcast(one_axis_one))
+                                                .reshape(batch_classes);
 }
 
 template <typename DeviceContext, typename T, bool is_test, typename Enable>
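
Note on the change: the old code shifted logits by the per-row maximum taken
across all of num_classes before clipping with ValueClip<T>. When the softmax
axis is not the innermost one, a group of axis_dim values can sit far below
the row-wide maximum, get clipped wholesale, and come out with distorted
probabilities. The patch instead views the logits as
(batch_size, axis_dim, num_remain), takes both the maximum and the sum of
exponentials along the middle (axis) dimension only, broadcasts them back
over that dimension, and flattens the result to (batch_size, num_classes).

Below is a minimal standalone sketch of that reduction pattern, assuming
Eigen's unsupported Tensor module. It reuses the patch's variable names for
readability, but it is not the Paddle code itself: it hard-codes demo shapes
and omits the ValueClip<T> clamp the patch applies to the shifted logits.

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // Demo shapes: softmax runs over groups of axis_dim values, while the
  // num_remain trailing elements of each group index stay independent.
  const int batch_size = 2;
  const int axis_dim = 3;
  const int num_remain = 2;
  const int num_classes = axis_dim * num_remain;

  // RowMajor to match Paddle's EigenMatrix layout.
  Eigen::Tensor<float, 2, Eigen::RowMajor> logits(batch_size, num_classes);
  logits.setRandom();
  Eigen::Tensor<float, 2, Eigen::RowMajor> softmax(batch_size, num_classes);

  Eigen::DSizes<int, 1> along_axis(1);  // middle dim of the 3-D view
  Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
  Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
  Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
  Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);

  // View the class dimension as (axis_dim, num_remain): reductions now run
  // over the softmax axis only, wherever it sat in the original layout.
  auto view = logits.reshape(batch_axis_remain);

  // Subtract the per-group maximum for numerical stability.
  Eigen::Tensor<float, 3, Eigen::RowMajor> shifted =
      view - view.maximum(along_axis)
                 .eval()
                 .reshape(batch_one_remain)
                 .broadcast(one_axis_one);

  // Normalize by the per-group sum of exponentials, then flatten back.
  Eigen::Tensor<float, 3, Eigen::RowMajor> e = shifted.exp();
  softmax = (e * e.sum(along_axis)
                    .inverse()
                    .eval()
                    .reshape(batch_one_remain)
                    .broadcast(one_axis_one))
                .reshape(batch_classes);

  std::cout << softmax << std::endl;
  return 0;
}

Built with, e.g., g++ -std=c++11 -I <path-to-eigen> softmax_demo.cpp (the
file name is arbitrary), every axis_dim-sized group of the output sums to 1
for each (batch, remain) pair, which is the invariant the pre-patch code
could violate for interior axes.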