|
|
|
@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
int rank = x_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
|
|
|
|
|
"Input(X) and Input(Label) shall have the same rank.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
|
framework::slice_ddim(label_dims, 0, rank - 1),
|
|
|
|
|
"Input(X) and Input(Label) shall have the same shape "
|
|
|
|
|
"except the last dimension.");
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
|
|
|
|
|
framework::product(label_dims) > 0)) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
|
framework::slice_ddim(label_dims, 0, rank - 1),
|
|
|
|
|
"Input(X) and Input(Label) shall have the same shape "
|
|
|
|
|
"except the last dimension.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto y_dims = x_dims;
|
|
|
|
|
y_dims[rank - 1] = 1;
|
|
|
|
|