|
|
|
@ -22,28 +22,21 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
void InferShape(framework::InferShapeContextBase *ctx) const override {
|
|
|
|
|
// input check
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
|
|
|
|
|
"Input(Label) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
|
|
|
|
|
"Output(X2) shouldn't be null.");
|
|
|
|
|
auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
|
|
|
|
|
auto x1_dims = ctx.Input<framework::Tensor>("X1")->dims();
|
|
|
|
|
auto x2_dims = ctx.Input<framework::Tensor>("X2")->dims();
|
|
|
|
|
PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims) &&
|
|
|
|
|
(label_dims.size() == 2) && (label_dims[1] == 1),
|
|
|
|
|
"All inputs must be vector with the same size.");
|
|
|
|
|
auto act_t = ctx.Output<framework::LoDTensor>("Activated");
|
|
|
|
|
auto out_t = ctx.Output<framework::LoDTensor>("Out");
|
|
|
|
|
if (act_t) {
|
|
|
|
|
act_t->Resize(label_dims);
|
|
|
|
|
}
|
|
|
|
|
if (out_t) {
|
|
|
|
|
out_t->Resize(label_dims);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
auto x1_dims = ctx->GetInputDim("X1");
|
|
|
|
|
auto x2_dims = ctx->GetInputDim("X2");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
(label_dims == x1_dims) && (x1_dims == x2_dims) &&
|
|
|
|
|
(label_dims.size() == 2) && (label_dims[1] == 1),
|
|
|
|
|
"All inputs must be 2-D tensor with shape [batch_size x 1].");
|
|
|
|
|
ctx->SetOutputDim("Activated", label_dims);
|
|
|
|
|
ctx->SetOutputDim("Out", label_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -71,7 +64,7 @@ class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(2-D tensor with shape [batch_size x 1])"
|
|
|
|
|
"The output loss of MarginRankLoss operator");
|
|
|
|
|
"The output loss of MarginRankLoss operator.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
|
|
|
|
|
MarginRankLoss operator measures the loss given a pair of input {`X1`, `X2`}
|
|
|
|
@ -96,26 +89,17 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
|
|
|
|
|
"Input(Label) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Activated"),
|
|
|
|
|
"Intermediate(Activated) shouldn't be null.");
|
|
|
|
|
auto dims = ctx.Input<framework::Tensor>("X1")->dims();
|
|
|
|
|
auto *x1_grad =
|
|
|
|
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("X1"));
|
|
|
|
|
auto *x2_grad =
|
|
|
|
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("X2"));
|
|
|
|
|
if (x1_grad) {
|
|
|
|
|
x1_grad->Resize(dims);
|
|
|
|
|
}
|
|
|
|
|
if (x2_grad) {
|
|
|
|
|
x2_grad->Resize(dims);
|
|
|
|
|
}
|
|
|
|
|
void InferShape(framework::InferShapeContextBase *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Activated"),
|
|
|
|
|
"Intermediate(Activated) shouldn't be null.");
|
|
|
|
|
auto dims = ctx->GetInputDim("Label");
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X1"), dims);
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X2"), dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|