|
|
|
@ -28,12 +28,12 @@ class BprLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto label_Pos_dims = ctx->GetInputDim("LabelPos");
|
|
|
|
|
auto label_pos_dims = ctx->GetInputDim("LabelPos");
|
|
|
|
|
int rank = x_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(rank, label_Pos_dims.size(),
|
|
|
|
|
PADDLE_ENFORCE_EQ(rank, label_pos_dims.size(),
|
|
|
|
|
"Input(X) and Input(LabelPos) shall have the same rank.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
|
framework::slice_ddim(label_Pos_dims, 0, rank - 1),
|
|
|
|
|
framework::slice_ddim(label_pos_dims, 0, rank - 1),
|
|
|
|
|
"Input(X) and Input(LabelPos) shall have the same shape "
|
|
|
|
|
"except the last dimension.");
|
|
|
|
|
|
|
|
|
|