@@ -24,13 +24,13 @@ class SoftmaxOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of SoftmaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"),
-                   "Output(Y) of SoftmaxOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SoftmaxOp should not be null.");
 
     auto x_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE(x_dims.size() == 2UL,
                    "The input of softmax op must be a matrix.");
-    ctx->SetOutputDim("Y", x_dims);
+    ctx->SetOutputDim("Out", x_dims);
   }
 };
 
@@ -41,7 +41,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X",
              "The input tensor of softmax. "
              "2-D with shape [batch_size, input_feature_dimensions].");
-    AddOutput("Y", "The normalized values with the same shape as X.");
+    AddOutput("Out", "The normalized values with the same shape as X.");
     AddComment(R"DOC(
 Softmax Operator.
 
@@ -59,7 +59,7 @@ exponential values of all the other dimensions is the output of the softmax
 operator.
 
 For each row $i$ and each column $j$ in Input(X), we have:
-    $$Y[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$
+    $$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$
 
 )DOC");
   }
@@ -70,12 +70,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input(Y@GRAD) should be not null.");
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"),
-                      ctx->GetInputDim(framework::GradVarName("Y")),
-                      "Input(Y) and its gradients should have a same shape.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Out"),
+                      ctx->GetInputDim(framework::GradVarName("Out")),
+                      "Input(Out) and its gradients should have a same shape.");
 
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
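For reference, the computation described by the DOC string above can be sketched in a few lines of plain C++. This is an illustrative, standalone sketch, not PaddlePaddle's kernel: it assumes a dense [batch_size, input_feature_dimensions] matrix with non-empty rows, and the row-max subtraction is a standard numerical-stability trick that is not part of the formula shown in the diff.

// Illustrative sketch only (not PaddlePaddle code): row-wise softmax on a
// dense [batch_size, feature_dim] matrix, matching
// Out[i, j] = exp(X[i, j]) / sum_j exp(X[i, j]).
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<std::vector<float>> Softmax(
    const std::vector<std::vector<float>>& x) {
  std::vector<std::vector<float>> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const std::vector<float>& row = x[i];
    // Subtract the row maximum before exponentiating to avoid overflow
    // (assumes each row is non-empty).
    float max_val = *std::max_element(row.begin(), row.end());
    float sum = 0.0f;
    out[i].resize(row.size());
    for (size_t j = 0; j < row.size(); ++j) {
      out[i][j] = std::exp(row[j] - max_val);  // exp(X[i, j]), shifted by the row max
      sum += out[i][j];
    }
    for (size_t j = 0; j < row.size(); ++j) {
      out[i][j] /= sum;  // normalize so each row sums to 1
    }
  }
  return out;
}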