|
|
@ -23,9 +23,9 @@ class SoftmaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
|
|
|
|
PADDLE_ENFORCE(ctx.Input<Tensor>("logits")->dims().size() == 2UL,
|
|
|
|
"The input of softmax op must be a matrix.");
|
|
|
|
"The input of softmax op must be a matrix.");
|
|
|
|
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
|
|
|
|
ctx.Output<Tensor>("softmax")->Resize(ctx.Input<Tensor>("logits")->dims());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -34,10 +34,10 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
SoftmaxOpMaker(framework::OpProto *proto,
|
|
|
|
SoftmaxOpMaker(framework::OpProto *proto,
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
AddInput("X",
|
|
|
|
AddInput("logits",
|
|
|
|
"The input tensor of softmax. "
|
|
|
|
"The input tensor of softmax. "
|
|
|
|
"2-D with shape [batch_size, input_feature_dimensions].");
|
|
|
|
"2-D with shape [batch_size, input_feature_dimensions].");
|
|
|
|
AddOutput("Y", "The normalized values with the same shape as X.");
|
|
|
|
AddOutput("softmax", "The normalized values with the same shape as X.");
|
|
|
|
AddComment(R"DOC(
|
|
|
|
AddComment(R"DOC(
|
|
|
|
The input of softmax operator is a 2-D tensor with shape N x K (N is the
|
|
|
|
The input of softmax operator is a 2-D tensor with shape N x K (N is the
|
|
|
|
batch_size, K is the dimension of input feature). The output tensor has the
|
|
|
|
batch_size, K is the dimension of input feature). The output tensor has the
|
|
|
@ -64,14 +64,17 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("softmax"),
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
|
|
|
|
"Input(softmax) should be not null.");
|
|
|
|
"Input(Y@GRAD) should not be null");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("softmax")),
|
|
|
|
PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
|
|
|
|
"Input(softmax@GRAD) should be not null.");
|
|
|
|
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
"the shape of Input(0) and Input(1) should be the same");
|
|
|
|
ctx.Input<Tensor>("softmax")->dims(),
|
|
|
|
ctx.Output<Tensor>(framework::GradVarName("X"))
|
|
|
|
ctx.Input<Tensor>(framework::GradVarName("softmax"))->dims(),
|
|
|
|
->Resize(ctx.Input<Tensor>("Y")->dims());
|
|
|
|
"Input(softmax) and its gradients should have a same shape.");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx.Output<Tensor>(framework::GradVarName("logits"))
|
|
|
|
|
|
|
|
->Resize(ctx.Input<Tensor>("logits")->dims());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|