|
|
@ -31,7 +31,7 @@ class TopkOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
PADDLE_ENFORCE_EQ(input_dims.size(), 2,
|
|
|
|
PADDLE_ENFORCE_EQ(input_dims.size(), 2,
|
|
|
|
"Rank of TopK op's input must be 2.");
|
|
|
|
"Rank of TopK op's input must be 2.");
|
|
|
|
const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
|
|
|
|
const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
|
|
|
|
PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
|
|
|
|