|
|
|
@ -37,12 +37,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
|
|
|
|
|
"Input(Label)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Label) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
|
|
|
|
|
"The 2nd dimension of "
|
|
|
|
|
"Input(Label) should be 1.");
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Label) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
|
|
|
|
|
"The 2nd dimension of "
|
|
|
|
|
"Input(Label) should be 1.");
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Y", {x_dims[0], 1});
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Y");
|
|
|
|
|
}
|
|
|
|
@ -99,17 +101,20 @@ class TeacherStudentSigmoidLossGradientOp
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Label) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Y@Grad) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dy_dims[1], 1,
|
|
|
|
|
"The 2nd dimension of Input(Y@Grad) should be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[1], 1,
|
|
|
|
|
"When Attr(soft_label) == false, the 2nd dimension of "
|
|
|
|
|
"Input(Label) should be 1.");
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Label) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[0], dy_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Y@Grad) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dy_dims[1], 1,
|
|
|
|
|
"The 2nd dimension of Input(Y@Grad) should be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[1], 1,
|
|
|
|
|
"When Attr(soft_label) == false, the 2nd dimension of "
|
|
|
|
|
"Input(Label) should be 1.");
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
ctx->ShareLoD("X", framework::GradVarName("X"));
|
|
|
|
|
}
|
|
|
|
|