|
|
|
@ -17,8 +17,6 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using framework::LoDTensor;
|
|
|
|
|
|
|
|
|
|
class CrossEntropyOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -35,23 +33,21 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
|
|
|
|
|
"Input(Label)'s rank must be 2.");
|
|
|
|
|
// TODO(xinghai-sun): remove this check after swtiching to bool
|
|
|
|
|
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
|
|
|
|
|
ctx.Attr<int>("soft_label") == 1);
|
|
|
|
|
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Label) must "
|
|
|
|
|
"be equal.");
|
|
|
|
|
if (ctx.Attr<int>("soft_label") == 1) {
|
|
|
|
|
if (ctx.Attr<bool>("soft_label")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
|
|
|
|
|
"If Attr(soft_label) == 1, The 2nd dimension of "
|
|
|
|
|
"If Attr(soft_label) == true, The 2nd dimension of "
|
|
|
|
|
"Input(X) and Input(Label) must be equal.");
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
|
|
|
|
|
"If Attr(soft_label) == 0, The 2nd dimension of "
|
|
|
|
|
"If Attr(soft_label) == false, The 2nd dimension of "
|
|
|
|
|
"Input(Label) must be 1.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx.Output<LoDTensor>("Y")->Resize({x->dims()[0], 1});
|
|
|
|
|
ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
|
|
|
|
|
ctx.ShareLoD("X", /*->*/ "Y");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -74,9 +70,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
|
|
|
|
|
"Input(Label)'s rank must be 2.");
|
|
|
|
|
// TODO(xinghai-sun): remove this check after swtiching to bool
|
|
|
|
|
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
|
|
|
|
|
ctx.Attr<int>("soft_label") == 1);
|
|
|
|
|
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Label) must "
|
|
|
|
|
"be equal.");
|
|
|
|
@ -85,17 +78,17 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
|
|
|
|
|
"The 2nd dimension of Input(Y@Grad) must be 1.");
|
|
|
|
|
if (ctx.Attr<int>("soft_label") == 1) {
|
|
|
|
|
if (ctx.Attr<bool>("soft_label")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
|
|
|
|
|
"If Attr(soft_label) == 1, The 2nd dimension of "
|
|
|
|
|
"If Attr(soft_label) == true, The 2nd dimension of "
|
|
|
|
|
"Input(X) and Input(Label) must be equal.");
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
|
|
|
|
|
"If Attr(soft_label) == 0, The 2nd dimension of "
|
|
|
|
|
"If Attr(soft_label) == false, The 2nd dimension of "
|
|
|
|
|
"Input(Label) must be 1.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
|
|
|
|
|
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
dx->Resize(x->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -108,7 +101,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X", "The first input of CrossEntropyOp");
|
|
|
|
|
AddInput("Label", "The second input of CrossEntropyOp");
|
|
|
|
|
AddOutput("Y", "The output of CrossEntropyOp");
|
|
|
|
|
AddAttr<int>("soft_label", "Is soft label. Default zero.").SetDefault(0);
|
|
|
|
|
AddAttr<bool>("soft_label", "Is soft label. Default zero.")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
CrossEntropy Operator.
|
|
|
|
@ -116,12 +110,12 @@ CrossEntropy Operator.
|
|
|
|
|
It supports both standard cross-entropy and soft-label cross-entropy loss
|
|
|
|
|
computation.
|
|
|
|
|
1) One-hot cross-entropy:
|
|
|
|
|
soft_label = 0, Label[i, 0] indicates the class index for sample i:
|
|
|
|
|
soft_label = False, Label[i, 0] indicates the class index for sample i:
|
|
|
|
|
|
|
|
|
|
Y[i] = -log(X[i, Label[i]])
|
|
|
|
|
|
|
|
|
|
2) Soft-label cross-entropy:
|
|
|
|
|
soft_label = 1, Label[i, j] indicates the soft label of class j
|
|
|
|
|
soft_label = True, Label[i, j] indicates the soft label of class j
|
|
|
|
|
for sample i:
|
|
|
|
|
|
|
|
|
|
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
|
|
|
|
@ -133,6 +127,9 @@ computation.
|
|
|
|
|
As a special case of 2), when each row of Input(Label) has only one
|
|
|
|
|
non-zero element (equals 1), soft-label cross-entropy degenerates to a
|
|
|
|
|
one-hot cross-entropy with one-hot label representation.
|
|
|
|
|
|
|
|
|
|
Both the input `X` and `Label` can carry the LoD (Level of Details) information,
|
|
|
|
|
or not. But the output only shares the LoD with input `X`.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|