|
|
|
@ -24,7 +24,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
|
|
|
|
|
"The input of softmax op must be matrix");
|
|
|
|
|
"The input of softmax op must be a matrix.");
|
|
|
|
|
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -34,9 +34,27 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
SoftmaxOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "input of softmax");
|
|
|
|
|
AddOutput("Y", "output of softmax");
|
|
|
|
|
AddComment("Softmax Op");
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"The input tensor of softmax. "
|
|
|
|
|
"2-D with shape [batch_size, input_feature_dimensions].");
|
|
|
|
|
AddOutput("Y", "The normalized values with the same shape as X.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
The input of softmax operator is a 2-D tensor with shape N x K (N is the
|
|
|
|
|
batch_size, K is the dimension of input feature). The output tensor has the
|
|
|
|
|
same shape as the input tensor.
|
|
|
|
|
|
|
|
|
|
For each row of the input tensor, the softmax operator squashes the
|
|
|
|
|
K-dimensional vector of arbitrary real values to a K-dimensional vector of real
|
|
|
|
|
values in the range [0, 1] that add up to 1. Specifically, it computes the
|
|
|
|
|
exponential of the given dimension and the sum of exponential values of all
|
|
|
|
|
the other dimensions in the K-dimensional vector input. Then the ratio of the
|
|
|
|
|
exponential of the given dimension and the sum of exponential values of all
|
|
|
|
|
the other dimensions is the output of the softmax operator.
|
|
|
|
|
|
|
|
|
|
For each row `i` and each column `j` in X, we have:
|
|
|
|
|
Y[i, j] = exp(X[i, j]) / sum_j(exp(X[i, j]))
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|