From 1cca7114c647a3a8833eeb2ae981e4392ecb347d Mon Sep 17 00:00:00 2001 From: heqiaozhi Date: Tue, 16 Apr 2019 10:08:05 +0800 Subject: [PATCH] fix infer test=develop --- .../teacher_student_sigmoid_loss_op.cc | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc index 6a4bea9437..7f95d16f09 100644 --- a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc @@ -37,12 +37,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); - PADDLE_ENFORCE_EQ(label_dims[1], 1UL, - "The 2nd dimension of " - "Input(Label) should be 1."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], + "The 1st dimension of Input(X) and Input(Label) should " + "be equal."); + PADDLE_ENFORCE_EQ(label_dims[1], 1UL, + "The 2nd dimension of " + "Input(Label) should be 1."); + } ctx->SetOutputDim("Y", {x_dims[0], 1}); ctx->ShareLoD("X", /*->*/ "Y"); } @@ -99,17 +101,20 @@ class TeacherStudentSigmoidLossGradientOp PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); - PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], - "The 1st dimension of Input(X) and Input(Y@Grad) should " - "be equal."); - PADDLE_ENFORCE_EQ(dy_dims[1], 1, - "The 2nd dimension of Input(Y@Grad) should be 1."); - PADDLE_ENFORCE_EQ(label_dims[1], 1, - "When Attr(soft_label) == false, the 2nd dimension of " - "Input(Label) should be 1."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], + "The 1st dimension of Input(X) and Input(Label) should " + "be equal."); + PADDLE_ENFORCE_EQ( + x_dims[0], dy_dims[0], + "The 1st dimension of Input(X) and Input(Y@Grad) should " + "be equal."); + PADDLE_ENFORCE_EQ(dy_dims[1], 1, + "The 2nd dimension of Input(Y@Grad) should be 1."); + PADDLE_ENFORCE_EQ(label_dims[1], 1, + "When Attr(soft_label) == false, the 2nd dimension of " + "Input(Label) should be 1."); + } ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->ShareLoD("X", framework::GradVarName("X")); }