|
|
|
@ -61,23 +61,31 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto query_dim = ctx->GetInputDim("QueryID");
|
|
|
|
|
PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
label_dim[0], score_dim[0],
|
|
|
|
|
"Tensor Score and Label should have the same height (batch size).");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dim[1], 1,
|
|
|
|
|
"The width of Label should be 1, i.e. each item should "
|
|
|
|
|
"have a scalar label.");
|
|
|
|
|
PADDLE_ENFORCE(query_dim == label_dim,
|
|
|
|
|
"QueryID should have the same shape as Label.");
|
|
|
|
|
if (ctx->HasInput("Weight")) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
|
|
|
|
|
"Weight should have the same shape as Label.");
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime() ||
|
|
|
|
|
(score_dim[0] > 0 && label_dim[0] > 0 && query_dim[0] > 0)) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
label_dim[0], score_dim[0],
|
|
|
|
|
"Tensor Score and Label should have the same height (batch size).");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dim[1], 1,
|
|
|
|
|
"The width of Label should be 1, i.e. each item should "
|
|
|
|
|
"have a scalar label.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(query_dim == label_dim,
|
|
|
|
|
"QueryID should have the same shape as Label.");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("Weight")) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
|
|
|
|
|
"Weight should have the same shape as Label.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int column = ctx->Attrs().Get<int>("column");
|
|
|
|
|
auto depth = score_dim[1];
|
|
|
|
|
PADDLE_ENFORCE(column < depth && column >= -depth,
|
|
|
|
|
"Attribute column should be in the range of [-%l, %l)",
|
|
|
|
|
depth, depth);
|
|
|
|
|
}
|
|
|
|
|
int column = ctx->Attrs().Get<int>("column");
|
|
|
|
|
auto depth = score_dim[1];
|
|
|
|
|
PADDLE_ENFORCE(column < depth && column >= -depth,
|
|
|
|
|
"Attribute column should be in the range of [-%l, %l)",
|
|
|
|
|
depth, depth);
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("PositivePair", scalar_dim);
|
|
|
|
|
ctx->SetOutputDim("NegativePair", scalar_dim);
|
|
|
|
|