@@ -82,40 +82,38 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"),
-                            "Input(Logits) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) should be not null.");
-
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"),
-                            "Output(Softmax) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"),
-                            "Output(Loss) should be not null.");
-
-    const Tensor* logits = ctx.Input<Tensor>("Logits");
-    const Tensor* labels = ctx.Input<Tensor>("Label");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Logits"),
+                   "Input(Logits) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
+                   "Output(Softmax) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");
+
+    auto logits_dims = ctx->GetInputDim("Logits");
+    auto labels_dims = ctx->GetInputDim("Label");
     PADDLE_ENFORCE_EQ(
-        logits->dims().size(), 2UL,
+        logits_dims.size(), 2UL,
         "The input of softmax_with_cross_entropy should be a 2-D tensor.");
-    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
+    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
                       "The labels should be a 2-D tensor.");
 
-    if (ctx.Attr<bool>("softLabel")) {
-      PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1],
+    if (ctx->Attrs().Get<bool>("softLabel")) {
+      PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
                         "If Attr(softLabel) == true, the 2nd dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
+      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
                         "If Attr(softLabel) == false, the 2nd dimension of "
                         "Input(Label) should be 1.");
     }
 
-    ctx.Output<framework::Tensor>("Softmax")->Resize(logits->dims());
-    ctx.Output<framework::Tensor>("Loss")->Resize({logits->dims()[0], 1});
+    ctx->SetOutputDim("Softmax", logits_dims);
+    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
 
-    ctx.ShareLoD("Logits", /*->*/ "Softmax");
-    ctx.ShareLoD("Logits", /*->*/ "Loss");
+    ctx->ShareLoD("Logits", /*->*/ "Softmax");
+    ctx->ShareLoD("Logits", /*->*/ "Loss");
   }
 };
 
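For orientation, here is a minimal standalone sketch of the shape contract that the rewritten forward InferShape enforces. It is not part of the patch: Dims and InferForwardShapes are made-up illustrative names, and plain asserts stand in for framework::DDim and PADDLE_ENFORCE so the snippet compiles without the Paddle headers.

    // Standalone illustration only -- not part of this change.
    // Shape rules mirrored from SoftmaxWithCrossEntropyOp::InferShape above.
    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    using Dims = std::vector<int64_t>;  // stand-in for framework::DDim

    // Returns {softmax_dims, loss_dims} for the given inputs.
    std::pair<Dims, Dims> InferForwardShapes(const Dims& logits_dims,
                                             const Dims& labels_dims,
                                             bool soft_label) {
      assert(logits_dims.size() == 2);             // Logits must be a 2-D tensor
      assert(labels_dims.size() == 2);             // Label must be a 2-D tensor
      if (soft_label) {
        assert(logits_dims[1] == labels_dims[1]);  // class dimensions must match
      } else {
        assert(labels_dims[1] == 1);               // one class index per row
      }
      // Softmax keeps the Logits shape [N, K]; Loss is one value per row, [N, 1].
      return {logits_dims, Dims{logits_dims[0], 1}};
    }

    int main() {
      auto s = InferForwardShapes({128, 10}, {128, 1}, /*soft_label=*/false);
      std::cout << "Softmax: " << s.first[0] << "x" << s.first[1]
                << ", Loss: " << s.second[0] << "x" << s.second[1] << "\n";
      return 0;
    }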
@@ -124,33 +122,32 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
-                            "Input(Loss@Grad) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
-                            "Input(Softmax) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")),
-                            "Output(Logits@Grad) should be not null.");
-
-    const Tensor* softmax = ctx.Input<Tensor>("Softmax");
-    const Tensor* labels = ctx.Input<Tensor>("Label");
-    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
+                   "Input(Loss@Grad) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Softmax"),
+                   "Input(Softmax) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
+                   "Output(Logits@Grad) should be not null.");
+
+    auto softmax_dims = ctx->GetInputDim("Softmax");
+    auto labels_dims = ctx->GetInputDim("Label");
+    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
                       "The labels should be a 2-D tensor.");
 
-    if (ctx.Attr<bool>("softLabel")) {
-      PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1],
+    if (ctx->Attrs().Get<bool>("softLabel")) {
+      PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
                         "When Attr(softLabel) == true, the 2nd dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
+      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
                         "When Attr(softLabel) == false, the 2nd dimension of "
                         "Input(Label) should be 1.");
     }
 
-    ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
-        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
+    ctx->SetOutputDim(framework::GradVarName("Logits"),
+                      ctx->GetInputDim("Softmax"));
   }
 };
 
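Both overrides now go through the InferShapeContextBase pointer interface (HasInput/HasOutput, GetInputDim/SetOutputDim, Attrs().Get, ShareLoD) instead of dereferencing Tensor variables through the old InferShapeContext reference. The backward shape rule from the grad-op hunk, in the same illustrative style as the sketch above (again not part of the patch; Dims and InferLogitsGradShape are made-up names):

    #include <cstdint>
    #include <vector>

    using Dims = std::vector<int64_t>;  // stand-in for framework::DDim

    // Mirrors SoftmaxWithCrossEntropyOpGrad::InferShape above: the gradient
    // w.r.t. Logits simply takes the dims of the saved Softmax output, [N, K].
    Dims InferLogitsGradShape(const Dims& softmax_dims) { return softmax_dims; }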