|
|
|
@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/positive_negative_pair_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -19,24 +20,19 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("Score"),
|
|
|
|
|
"Input(Score) of PositiveNegativePairOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("Label"),
|
|
|
|
|
"Input(Label) of PositiveNegativePairOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("QueryID"),
|
|
|
|
|
"Input(QueryID) of PositiveNegativePairOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("PositivePair"),
|
|
|
|
|
"Output(PositivePair) of PositiveNegativePairOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("NegativePair"),
|
|
|
|
|
"Output(NegativePair) of PositiveNegativePairOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("NeutralPair"),
|
|
|
|
|
"Output(NeutralPair) of PositiveNegativePairOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Score"), "Input", "Score",
|
|
|
|
|
"positive_negative_pair");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label",
|
|
|
|
|
"positive_negative_pair");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("QueryID"), "Input", "QueryID",
|
|
|
|
|
"positive_negative_pair");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("PositivePair"), "Output", "PositivePair",
|
|
|
|
|
"positive_negative_pair");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("NegativePair"), "Output", "NegativePair",
|
|
|
|
|
"positive_negative_pair");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("NeutralPair"), "Output", "NeutralPair",
|
|
|
|
|
"positive_negative_pair");
|
|
|
|
|
|
|
|
|
|
auto scalar_dim = framework::make_ddim({1});
|
|
|
|
|
if (ctx->HasInput("AccumulatePositivePair") ||
|
|
|
|
|
ctx->HasInput("AccumulateNegativePair") ||
|
|
|
|
@ -48,43 +44,93 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
|
|
|
|
|
"AccumulateNegativePair, AccumulateNeutralPair) of "
|
|
|
|
|
"PositiveNegativePairOp are required if one of them is "
|
|
|
|
|
"specified.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulatePositivePair"), scalar_dim,
|
|
|
|
|
"Shape of AccumulatePositivePair should be {1}.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNegativePair"), scalar_dim,
|
|
|
|
|
"Shape of AccumulateNegativePair should be {1}.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNeutralPair"), scalar_dim,
|
|
|
|
|
"Shape of AccumulateNeutralPair should be {1}.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputDim("AccumulatePositivePair"), scalar_dim,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape of Input(AccumulatePositivePair) should be [1]. Received "
|
|
|
|
|
"shape of Input(AccumulatePositivePair): [%s].",
|
|
|
|
|
ctx->GetInputDim("AccumulatePositivePair")));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputDim("AccumulateNegativePair"), scalar_dim,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape of Input(AccumulateNegativePair) should be [1]. Received "
|
|
|
|
|
"shape of Input(AccumulateNegativePair): [%s].",
|
|
|
|
|
ctx->GetInputDim("AccumulateNegativePair")));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputDim("AccumulateNeutralPair"), scalar_dim,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape of Input(AccumulateNeutralPair) should be [1]. Received "
|
|
|
|
|
"shape of Input(AccumulateNeutralPair): [%s].",
|
|
|
|
|
ctx->GetInputDim("AccumulateNeutralPair")));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto score_dim = ctx->GetInputDim("Score");
|
|
|
|
|
auto label_dim = ctx->GetInputDim("Label");
|
|
|
|
|
auto query_dim = ctx->GetInputDim("QueryID");
|
|
|
|
|
PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(score_dim.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Score should be a 2-D tensor. Received shape of "
|
|
|
|
|
"Input(Score): [%s].",
|
|
|
|
|
score_dim));
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dim.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Label should be a 2-D tensor. Received shape of "
|
|
|
|
|
"Input(Label): [%s].",
|
|
|
|
|
label_dim));
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime() ||
|
|
|
|
|
(score_dim[0] > 0 && label_dim[0] > 0 && query_dim[0] > 0)) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
label_dim[0], score_dim[0],
|
|
|
|
|
"Tensor Score and Label should have the same height (batch size).");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Score) and Input(Label) should have the same "
|
|
|
|
|
"height (batch size). Received: the shape of Input(Score) is "
|
|
|
|
|
"[%s], while the shape of Input(Label) is [%s]. The first "
|
|
|
|
|
"dimensions of them are different.",
|
|
|
|
|
label_dim, score_dim));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dim[1], 1,
|
|
|
|
|
"The width of Label should be 1, i.e. each item should "
|
|
|
|
|
"have a scalar label.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
label_dim[1], 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The width of Label should be 1, i.e. each item should "
|
|
|
|
|
"have a scalar label. Received shape of Input(Label) is [%s]. "
|
|
|
|
|
"The second dimension of it is %d, while the expected is %d.",
|
|
|
|
|
label_dim, label_dim[1], 1));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(query_dim == label_dim,
|
|
|
|
|
"QueryID should have the same shape as Label.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
query_dim, label_dim,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(QueryID) should have the same shape as Input(Label). "
|
|
|
|
|
"Received: the shape of Input(QueryID) is [%s], "
|
|
|
|
|
"while the shape of Input(Label) is [%s].",
|
|
|
|
|
query_dim, label_dim));
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("Weight")) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
|
|
|
|
|
"Weight should have the same shape as Label.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputDim("Weight"), label_dim,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Weight) should have the same shape as Input(Label). "
|
|
|
|
|
"Received: the shape of Input(Weight) is [%s] while the shape "
|
|
|
|
|
"of Input(Label) is [%s].",
|
|
|
|
|
ctx->GetInputDim("Weight"), label_dim));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int column = ctx->Attrs().Get<int>("column");
|
|
|
|
|
auto depth = score_dim[1];
|
|
|
|
|
PADDLE_ENFORCE(column < depth && column >= -depth,
|
|
|
|
|
"Attribute column should be in the range of [-%l, %l)",
|
|
|
|
|
depth, depth);
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
column, depth,
|
|
|
|
|
platform::errors::OutOfRange(
|
|
|
|
|
"Attr(column) should be less than depth(the second "
|
|
|
|
|
"dimension of Input(Score)). Recieved Attr(column): %d, while "
|
|
|
|
|
"depth is %d.",
|
|
|
|
|
column, depth));
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
column, -depth,
|
|
|
|
|
platform::errors::OutOfRange(
|
|
|
|
|
"Attr(column) should be greater than equal to negative "
|
|
|
|
|
"depth, i.e. the second dimension of Input(Score). "
|
|
|
|
|
"Recieved Attr(column): %d, while negative depth is %d.",
|
|
|
|
|
column, -depth));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("PositivePair", scalar_dim);
|
|
|
|
|