@@ -23,16 +23,32 @@ class SoftmaxWithCrossEntropyOpMaker
   SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
                                  framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    // TODO(caoying): replace int with boolean
+    AddAttr<int>("soft_label",
+                 "(int, default 0), A flag to indicate whether to interpret "
+                 "the given labels as soft labels.")
+        .SetDefault(0);
     AddInput("Logits",
-             "The unscaled log probabilities which is a 2-D tensor<float> with"
-             "shape [N x K]. N is the batch_size, and K is the class number.")
+             "(Tensor, default Tensor<float>), The unscaled log probabilities "
+             "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
+             "and K is the class number.")
         .NotInGradient();
-    AddInput("Label", "The ground truth. A 1-D tensor<int> with shape N.");
-    AddOutput("Softmax",
-              "Store the outputs of softmax function, "
-              "which will be used in backward calculation.")
+    AddInput(
+        "Label",
+        "(Tensor, default Tensor<int>), The ground truth which is "
+        "a 1-D or 2-D tensor. "
+        "If soft_label is set to 0, Label is a Tensor<int> with shape [N x 1]. "
+        "If soft_label is set to 1, Label is a Tensor<float/double> "
+        "with shape [N x K].");
+    AddOutput(
+        "Softmax",
+        "(Tensor, default Tensor<float>), A 2-D tensor with shape [N x K]. "
+        "The output values of the softmax activation for the given input "
+        "batch, which will be used in backward calculation.")
         .AsIntermediate();
-    AddOutput("Out", "A 1-D tensor<float> with shape N.");
+    AddOutput("Loss",
+              "(Tensor, default Tensor<float>), A 1-D tensor. The cross "
+              "entropy loss with shape [N x 1].");
     AddComment(R"DOC(
 Cross entropy loss with softmax are used as the output layer extensively. This
 operator computes the softmax normalized values for each row of the input
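
The key semantic change in this hunk: the new soft_label attribute switches the meaning of Label between a column of class indices (a [N x 1] Tensor<int>) and a per-sample distribution over all K classes (a [N x K] Tensor<float/double>). A minimal standalone sketch of the two loss conventions (illustrative names only, not Paddle code; it assumes the softmax probabilities are already computed):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Hard labels: one class index per sample, loss = -log(prob[label]).
double HardLabelLoss(const std::vector<double>& probs, int label) {
  return -std::log(probs[label]);
}

// Soft labels: a distribution over all K classes,
// loss = -sum_i label_i * log(prob_i).
double SoftLabelLoss(const std::vector<double>& probs,
                     const std::vector<double>& labels) {
  double loss = 0.0;
  for (size_t i = 0; i < probs.size(); ++i) {
    loss -= labels[i] * std::log(probs[i]);
  }
  return loss;
}

int main() {
  std::vector<double> probs = {0.7, 0.2, 0.1};  // softmax output, K = 3
  // soft_label == 0: Label is an int index in [0, K).
  std::printf("hard: %f\n", HardLabelLoss(probs, 0));
  // soft_label == 1: Label is a distribution over the K classes.
  std::printf("soft: %f\n", SoftLabelLoss(probs, {1.0, 0.0, 0.0}));
  return 0;
}
```

A one-hot soft label reproduces the hard-label loss exactly, which is why one operator can serve both modes.
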
@@ -46,25 +62,18 @@ which will produce incorrect results.
 This operators expects mutually exclusive hard labels, each sample in a batch
 is in exactly one class with probabilities 1. Each sample in the batch with one
 and only one label.
+
+Equation:
+
+1) hard label (one-hot label)
+
+Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=1}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., N
+
+2) soft label (a distribution over all classes)
+
+Loss_j = -\sum_{i=1}^{K}\text{Label}_i\left(\text{Logit}_i - \log\left(\sum_{k=1}^{K}\exp(\text{Logit}_k)\right)\right), j = 1, ..., N
+
 )DOC");
   }
 };
-
-class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@Grad) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
-                            "Input(Softmax) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Lable) should be not null.");
-
-    ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
-        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
-  }
-};
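
The DOC text insists on unscaled logits because the fused operator can evaluate the hard-label equation above with the max-shifted log-sum-exp, whereas feeding it pre-softmaxed probabilities loses precision that log() cannot recover. A hedged sketch of that evaluation (plain C++, not the operator's actual kernel):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Loss = -Logit[label] + log(sum_i exp(Logit_i)), evaluated as
// log-sum-exp(x) = m + log(sum_i exp(x_i - m)) with m = max_i x_i,
// so no exp() can overflow.
double HardLabelLossFromLogits(const std::vector<double>& logits, int label) {
  double m = *std::max_element(logits.begin(), logits.end());
  double sum = 0.0;
  for (double x : logits) sum += std::exp(x - m);
  return -logits[label] + m + std::log(sum);
}

int main() {
  // Logits this large would overflow a naive exp(); the shifted form is safe.
  std::vector<double> logits = {1000.0, 1001.0, 999.0};
  std::printf("loss: %f\n", HardLabelLossFromLogits(logits, 1));  // ~0.4076
  return 0;
}
```
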
@@ -82,7 +91,25 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
                    "The label should be a 1-d tensor.");
 
     ctx.Output<framework::LoDTensor>("Softmax")->Resize(logits->dims());
-    ctx.Output<framework::LoDTensor>("Out")->Resize({logits->dims()[0], 1});
+    ctx.Output<framework::LoDTensor>("Loss")->Resize({logits->dims()[0], 1});
   }
 };
+
+class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
+                            "Input(Loss@Grad) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
+                            "Input(Softmax) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) should not be null.");
+
+    ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
+        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
+  }
+};
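
Why the gradient op simply resizes Logits@Grad to the dims of the cached Softmax output: for hard labels the well-known identity is dLoss/dLogit_i = softmax_i - 1{i == Label}, so the backward pass needs nothing beyond the intermediate Softmax tensor saved in forward. A small illustrative sketch (hypothetical helper, not Paddle's kernel):

```cpp
#include <cstdio>
#include <vector>

// Gradient for one sample: copy the cached softmax row (same dims, as in
// InferShape above), subtract 1 at the label position, and scale by the
// incoming Loss gradient.
std::vector<double> LogitsGrad(const std::vector<double>& softmax, int label,
                               double loss_grad) {
  std::vector<double> grad = softmax;
  grad[label] -= 1.0;
  for (double& g : grad) g *= loss_grad;
  return grad;
}

int main() {
  std::vector<double> softmax = {0.7, 0.2, 0.1};
  std::vector<double> grad = LogitsGrad(softmax, 0, 1.0);
  for (double g : grad) std::printf("%f ", g);  // -0.3 0.2 0.1
  std::printf("\n");
  return 0;
}
```
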