diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 8fbf299a7c..bd3b14775f 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -37,6 +37,15 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SoftmaxOp should not be null.");
 
+    // Check the new "axis" attribute: -1 (the default) selects the last
+    // dimension; otherwise it must index a dimension of Input(X).
+    auto dim_x = ctx->GetInputDim("X");
+    auto rank_x = dim_x.size();
+    auto axis = ctx->Attrs().Get<int>("axis");
+    PADDLE_ENFORCE(axis >= -1 && axis < rank_x,
+                   "Attr(axis) value should be greater than or equal to -1 "
+                   "and less than the rank of Input(X).");
+
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
@@ -80,6 +89,10 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
         "The input tensor of softmax, "
         "whose last dimension is the input_feature_dimensions.");
     AddOutput("Out", "The normalized values with the same shape as X.");
+    AddAttr<int>("axis",
+                 "The dimension of Input(x) to perform softmax, "
+                 "default -1 for last dimension")
+        .SetDefault(-1);
     AddAttr<bool>(
         "use_cudnn",
         "(bool, default false) Only used in cudnn kernel, need install cudnn")
diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h
index 91829d5761..ad41e52116 100644
--- a/paddle/fluid/operators/softmax_op.h
+++ b/paddle/fluid/operators/softmax_op.h
@@ -13,27 +13,75 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/softmax.h"
+#include "paddle/fluid/operators/transpose_op.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 
+// Moves dimension `axis` of x/out to the end so the existing last-dim
+// softmax implementation can be reused. When axis is already the last
+// dimension (or -1), x/out are shared into x_trans/out_trans without
+// copying and *perm is left empty. Otherwise the permutation used is
+// written to *perm; it simply swaps `axis` with rank-1, so it is its
+// own inverse and can be reused to transpose the result back.
+template <typename DeviceContext, typename T>
+static inline void TransposeAxisToEnd(const Tensor& x, const Tensor& out,
+                                      Tensor* x_trans, Tensor* out_trans,
+                                      const int axis, std::vector<int>* perm,
+                                      const framework::ExecutionContext& ctx) {
+  auto dim_x = x.dims();
+  int rank = dim_x.size();
+
+  if (axis == -1 || axis == rank - 1) {
+    *x_trans = x;
+    *out_trans = out;
+    return;
+  }
+
+  auto& dev_ctx = ctx.template device_context<DeviceContext>();
+  std::vector<int> shape;
+  for (int i = 0; i < rank - 1; i++) {
+    if (i == axis) {
+      perm->push_back(rank - 1);
+      shape.push_back(dim_x[rank - 1]);
+    } else {
+      perm->push_back(i);
+      shape.push_back(dim_x[i]);
+    }
+  }
+  perm->push_back(axis);
+  shape.push_back(dim_x[axis]);
+
+  x_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
+  out_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
+  TransCompute<DeviceContext, T>(rank, dev_ctx, x, x_trans, *perm);
+  TransCompute<DeviceContext, T>(rank, dev_ctx, out, out_trans, *perm);
+}
+
 template <typename DeviceContext, typename T>
 class SoftmaxKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* X = context.Input<Tensor>("X");
     auto* Out = context.Output<Tensor>("Out");
+    const int axis = context.Attr<int>("axis");
 
     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());
 
+    Tensor X_trans, Out_trans;
+    std::vector<int> perm;
+    TransposeAxisToEnd<DeviceContext, T>(*X, *Out, &X_trans, &Out_trans, axis,
+                                         &perm, context);
+
     int rank = X->dims().size();
-    Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
-    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+    Tensor X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
+    Tensor Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
 
 #ifdef PADDLE_ON_INFERENCE
     math::SoftmaxFunctor<DeviceContext, T, true>()(
@@ -42,6 +90,12 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     math::SoftmaxFunctor<DeviceContext, T, false>()(
         context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
 #endif
+
+    // perm swaps axis with the last dim, so applying it again restores
+    // the caller's original layout.
+    if (axis != -1 && axis != rank - 1) {
+      auto& dev_ctx = context.template device_context<DeviceContext>();
+      TransCompute<DeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
+    }
   }
 };
 