|
|
@ -34,8 +34,11 @@ class TopkOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
|
|
|
|
PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
|
|
|
|
PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape");
|
|
|
|
PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape");
|
|
|
|
PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k,
|
|
|
|
|
|
|
|
"input must have >= k columns");
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k,
|
|
|
|
|
|
|
|
"input must have >= k columns");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
framework::DDim dims = input_dims;
|
|
|
|
framework::DDim dims = input_dims;
|
|
|
|
dims[dims.size() - 1] = k;
|
|
|
|
dims[dims.size() - 1] = k;
|
|
|
|