|
|
|
@ -39,7 +39,7 @@ class RankLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
|
|
|
|
|
"All inputs must have the same size");
|
|
|
|
|
PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1),
|
|
|
|
|
"All inputs must be row vector with size batch_sizex1.");
|
|
|
|
|
"All inputs must be row vector with size batch_size x 1.");
|
|
|
|
|
ctx.Output<framework::LoDTensor>("Out")->Resize(label_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|