@@ -25,19 +25,21 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
+                      "Input(Label) should be not null.");
 
-    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
+                      "Output(Y) should be not null.");
+
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
     int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
-                      "Input(X) and Input(Label) shall have the same rank.");
+
     bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
                                framework::contain_unknown_dim(label_dims);
     bool check = ctx->IsRuntime() || !contain_unknown_dim;
 
     if (check) {
       PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                         framework::slice_ddim(label_dims, 0, rank - 1),
@@ -46,19 +48,30 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
     }
 
     if (IsSoftLabel(ctx)) {
+      PADDLE_ENFORCE_EQ(
+          rank, label_dims.size(),
+          "If Attr(soft_label) == true, Input(X) and Input(Label) "
+          "shall have the same rank.");
       if (check) {
         PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
                           "If Attr(soft_label) == true, the last dimension of "
                           "Input(X) and Input(Label) should be equal.");
       }
     } else {
+      if (rank == label_dims.size()) {
         PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
-                        "If Attr(softLabel) == false, the last dimension of "
-                        "Input(Label) should be 1.");
+                          "the last dimension of Input(Label) should be 1.");
+      } else {
+        PADDLE_ENFORCE_EQ(
+            rank, label_dims.size() + 1,
+            "The rank of Input(X) should be equal to Input(Label) plus 1.");
+      }
     }
 
-    auto y_dims = x_dims;
+    auto y_dims = label_dims;
+    if (rank == label_dims.size()) {
       y_dims[rank - 1] = 1;
+    }
     ctx->SetOutputDim("Y", y_dims);
     ctx->ShareLoD("X", /*->*/ "Y");
   }
@@ -82,20 +95,19 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const {
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
+                      "Input(Label) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Y")), true,
                       "Input(Y@GRAD) shoudl be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
                       "Output(X@GRAD) should be not null.");
 
     auto x_dims = GetXDim(ctx);
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
     int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
-                      "Input(Y@Grad) and Input(X) should have the same rank.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
-                      "Input(Label) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(dy_dims.size(), label_dims.size(),
+                      "Input(Y@Grad) and Input(Y) should have the same rank.");
 
     bool check = true;
     if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
@@ -104,30 +116,12 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
     }
 
     if (check) {
       PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(label_dims, 0, rank - 1),
-                        "The Input(X) and Input(Label) should have the same "
-                        "shape except the last dimension.");
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                         framework::slice_ddim(dy_dims, 0, rank - 1),
                         "The Input(X) and Input(Y@Grad) should have the same "
                         "shape except the last dimension.");
     }
-    if (IsSoftLabel(ctx)) {
-      if (check) {
-        PADDLE_ENFORCE_EQ(
-            x_dims[rank - 1], label_dims[rank - 1],
-            "When Attr(soft_label) == true, the last dimension of "
-            "Input(X) and Input(Label) should be equal.");
-      }
-    } else {
-      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
-                        "When Attr(soft_label) == false, the last dimension of "
-                        "Input(Label) should be 1.");
-    }
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
-                      "The last dimension of Input(Y@Grad) should be 1.");
 
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
     ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
   }
@@ -231,7 +225,7 @@ class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
   using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should be not null.");
     CrossEntropyGradientOpBase::InferShape(ctx);
   }
 };
@@ -260,10 +254,10 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
   void InferShape(framework::InferShapeContext* ctx) const override {
     CrossEntropyOpBase::InferShape(ctx);
 
-    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
                       "Output(XShape) should be not null.");
 
-    PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("MatchX"), true,
                       "Output(MatchX) should be not null.");
     auto x_dims = ctx->GetInputDim("X");
     auto x_dims_vec = framework::vectorize(x_dims);
@@ -284,7 +278,8 @@ class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
  public:
   using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("MatchX"), true,
+                      "Input(MatchX) must exist");
     CrossEntropyGradientOpBase::InferShape(ctx);
   }