|
|
|
@ -27,20 +27,55 @@ class RankLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
// input check
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
|
|
|
|
|
"Input(Label) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true,
|
|
|
|
|
"Input(Left) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true,
|
|
|
|
|
"Input(Right) shouldn't be null.");
|
|
|
|
|
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
auto left_dims = ctx->GetInputDim("Left");
|
|
|
|
|
auto right_dims = ctx->GetInputDim("Right");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
|
|
|
|
|
"All inputs must have the same size.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
(label_dims.size() == 2) && (label_dims[1] == 1),
|
|
|
|
|
"All inputs must be 2-D tensors with shape [batch_size x 1].");
|
|
|
|
|
// check label_dims valid
|
|
|
|
|
PADDLE_ENFORCE_GE(label_dims.size(), 1,
|
|
|
|
|
"The dimension size of Input(Label) must be greater than "
|
|
|
|
|
"or equal to 1.");
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
label_dims.size(), 2,
|
|
|
|
|
"The dimension size of Input(Label) must be less than or equal to 2.");
|
|
|
|
|
if (label_dims.size() == 2U) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[1], 1,
|
|
|
|
|
"The last dimension of Input(Label) must be 1.");
|
|
|
|
|
}
|
|
|
|
|
// check left_dims valid
|
|
|
|
|
PADDLE_ENFORCE_GE(left_dims.size(), 1,
|
|
|
|
|
"The dimension size of Input(Left) must be greater than "
|
|
|
|
|
"or equal to 1.");
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
left_dims.size(), 2,
|
|
|
|
|
"The dimension size of Input(Left) must be less than or equal to 2.");
|
|
|
|
|
if (left_dims.size() == 2U) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(left_dims[1], 1,
|
|
|
|
|
"The last dimension of Input(Left) must be 1.");
|
|
|
|
|
}
|
|
|
|
|
// check right_dims valid
|
|
|
|
|
PADDLE_ENFORCE_GE(right_dims.size(), 1,
|
|
|
|
|
"The dimension size of Input(Right) must be greater than "
|
|
|
|
|
"or equal to 1.");
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
right_dims.size(), 2,
|
|
|
|
|
"The dimension size of Input(Right) must be less than or equal to 2.");
|
|
|
|
|
if (right_dims.size() == 2U) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(right_dims[1], 1,
|
|
|
|
|
"The last dimension of Input(Right) must be 1.");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[0], left_dims[0],
|
|
|
|
|
"The first dimension of Input(Label) and Input(Left) "
|
|
|
|
|
"must have the same value.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[0], right_dims[0],
|
|
|
|
|
"The first dimension of Input(Label) and Input(Right) "
|
|
|
|
|
"must have the same value.");
|
|
|
|
|
ctx->SetOutputDim("Out", label_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -98,21 +133,25 @@ class RankLossGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
auto dims = ctx->GetInputDim("Left");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
|
|
|
|
|
"Input(Label) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true,
|
|
|
|
|
"Input(Left) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true,
|
|
|
|
|
"Input(Right) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
auto left_dims = ctx->GetInputDim("Left");
|
|
|
|
|
auto right_dims = ctx->GetInputDim("Right");
|
|
|
|
|
auto left_grad_name = framework::GradVarName("Left");
|
|
|
|
|
auto right_grad_name = framework::GradVarName("Right");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(left_grad_name)) {
|
|
|
|
|
ctx->SetOutputDim(left_grad_name, dims);
|
|
|
|
|
ctx->SetOutputDim(left_grad_name, left_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(right_grad_name)) {
|
|
|
|
|
ctx->SetOutputDim(right_grad_name, dims);
|
|
|
|
|
ctx->SetOutputDim(right_grad_name, right_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|