|
|
|
@ -106,24 +106,40 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
auto logits_dims = ctx->GetInputDim("Logits");
|
|
|
|
|
auto labels_dims = ctx->GetInputDim("Label");
|
|
|
|
|
|
|
|
|
|
int rank = logits_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
logits_dims.size(), 2UL,
|
|
|
|
|
"The input of softmax_with_cross_entropy should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
|
|
|
|
|
"The labels should be a 2-D tensor.");
|
|
|
|
|
rank, labels_dims.size(),
|
|
|
|
|
"Input(logits) and Input(Label) shall have the same rank.");
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) && (framework::product(logits_dims) <= 0 ||
|
|
|
|
|
framework::product(labels_dims) <= 0)) {
|
|
|
|
|
check = false;
|
|
|
|
|
}
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(logits_dims, 0, rank - 1),
|
|
|
|
|
framework::slice_ddim(labels_dims, 0, rank - 1),
|
|
|
|
|
"Input(X) and Input(Label) shall have the same shape "
|
|
|
|
|
"except the last dimension.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->Attrs().Get<bool>("soft_label")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
|
|
|
|
|
"If Attr(soft_label) == true, the 2nd dimension of "
|
|
|
|
|
"Input(X) and Input(Label) should be equal.");
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(logits_dims[rank - 1], labels_dims[rank - 1],
|
|
|
|
|
"If Attr(soft_label) == true, the last dimension of "
|
|
|
|
|
"Input(X) and Input(Label) should be equal.");
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
|
|
|
|
|
"If Attr(soft_label) == false, the 2nd dimension of "
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
|
|
|
|
|
"If Attr(softLabel) == false, the last dimension of "
|
|
|
|
|
"Input(Label) should be 1.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Softmax", logits_dims);
|
|
|
|
|
ctx->SetOutputDim("Loss", {logits_dims[0], 1});
|
|
|
|
|
auto loss_dims = logits_dims;
|
|
|
|
|
loss_dims[rank - 1] = 1;
|
|
|
|
|
ctx->SetOutputDim("Loss", loss_dims);
|
|
|
|
|
// ctx->SetOutputDim("Loss", {logits_dims[0], 1});
|
|
|
|
|
|
|
|
|
|
ctx->ShareLoD("Logits", /*->*/ "Softmax");
|
|
|
|
|
ctx->ShareLoD("Logits", /*->*/ "Loss");
|
|
|
|
@ -152,16 +168,33 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
auto softmax_dims = ctx->GetInputDim("Softmax");
|
|
|
|
|
auto labels_dims = ctx->GetInputDim("Label");
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
|
|
|
|
|
"The labels should be a 2-D tensor.");
|
|
|
|
|
|
|
|
|
|
int rank = softmax_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rank, labels_dims.size(),
|
|
|
|
|
"Input(logits) and Input(Label) shall have the same rank.");
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) && (framework::product(softmax_dims) <= 0 ||
|
|
|
|
|
framework::product(labels_dims) <= 0)) {
|
|
|
|
|
check = false;
|
|
|
|
|
}
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::slice_ddim(softmax_dims, 0, rank - 1),
|
|
|
|
|
framework::slice_ddim(labels_dims, 0, rank - 1),
|
|
|
|
|
"Input(Softmax) and Input(Label) shall have the same shape "
|
|
|
|
|
"except the last dimension.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->Attrs().Get<bool>("soft_label")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
|
|
|
|
|
"When Attr(soft_label) == true, the 2nd dimension of "
|
|
|
|
|
"Input(X) and Input(Label) should be equal.");
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(softmax_dims[rank - 1], labels_dims[rank - 1],
|
|
|
|
|
"If Attr(soft_label) == true, the last dimension of "
|
|
|
|
|
"Input( Softmax) and Input(Label) should be equal.");
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
|
|
|
|
|
"When Attr(soft_label) == false, the 2nd dimension of "
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
|
|
|
|
|
"If Attr(softLabel) == false, the last dimension of "
|
|
|
|
|
"Input(Label) should be 1.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|