|
|
|
@ -23,22 +23,26 @@ class BprLossOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BprLoss");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BprLoss");
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "BprLoss");
|
|
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
int rank = x_dims.size();
|
|
|
|
int rank = x_dims.size();
|
|
|
|
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
"Input(X) and Input(Label) shall have the same rank.");
|
|
|
|
rank, label_dims.size(),
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"Input(X) and Input(Label) shall have the same rank."));
|
|
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
|
|
|
|
if (ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
|
|
|
|
framework::product(label_dims) > 0)) {
|
|
|
|
framework::product(label_dims) > 0)) {
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
framework::slice_ddim(label_dims, 0, rank - 1),
|
|
|
|
framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
"Input(X) and Input(Label) shall have the same shape "
|
|
|
|
framework::slice_ddim(label_dims, 0, rank - 1),
|
|
|
|
"except the last dimension.");
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"Input(X) and Input(Label) shall have the same shape "
|
|
|
|
|
|
|
|
"except the last dimension."));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto y_dims = x_dims;
|
|
|
|
auto y_dims = x_dims;
|
|
|
|
@ -63,33 +67,41 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BprLossGradient");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BprLossGradient");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
|
|
|
|
"Input(Y@GRAD) shoudl be not null.");
|
|
|
|
framework::GradVarName("Y"), "BprLossGradient");
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
|
|
|
|
"Output(X@GRAD) should be not null.");
|
|
|
|
framework::GradVarName("X"), "BprLossGradient");
|
|
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
|
|
|
|
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
|
|
|
|
int rank = x_dims.size();
|
|
|
|
int rank = x_dims.size();
|
|
|
|
PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
"Input(Y@Grad) and Input(X) should have the same rank.");
|
|
|
|
dy_dims.size(), rank,
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), rank,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"Input(Label) and Input(X) should have the same rank.");
|
|
|
|
"Input(Y@Grad) and Input(X) should have the same rank."));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
|
|
|
label_dims.size(), rank,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"Input(Label) and Input(X) should have the same rank."));
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
framework::slice_ddim(label_dims, 0, rank - 1),
|
|
|
|
framework::slice_ddim(label_dims, 0, rank - 1),
|
|
|
|
"The Input(X) and Input(Label) should have the same "
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"shape except the last dimension.");
|
|
|
|
"The Input(X) and Input(Label) should have the same "
|
|
|
|
|
|
|
|
"shape except the last dimension."));
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
framework::slice_ddim(dy_dims, 0, rank - 1),
|
|
|
|
framework::slice_ddim(dy_dims, 0, rank - 1),
|
|
|
|
"The Input(X) and Input(Y@Grad) should have the same "
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"shape except the last dimension.");
|
|
|
|
"The Input(X) and Input(Y@Grad) should have the same "
|
|
|
|
|
|
|
|
"shape except the last dimension."));
|
|
|
|
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
|
|
|
|
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
|
|
|
|
"The last dimension of Input(Y@Grad) should be 1.");
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The last dimension of Input(Y@Grad) should be 1."));
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
|
|
|
|
" the last dimension of Input(Label) should be 1.");
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
" the last dimension of Input(Label) should be 1."));
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
ctx->ShareLoD("X", framework::GradVarName("X"));
|
|
|
|
ctx->ShareLoD("X", framework::GradVarName("X"));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|