|
|
|
@ -23,20 +23,18 @@ class TopkV2Op : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of TopkOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of TopkOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Indices"),
|
|
|
|
|
"Output(Indices) of TopkOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "topk_v2");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "topk_v2");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "topk_v2");
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
|
const int& dim_size = input_dims.size();
|
|
|
|
|
int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
|
|
|
|
|
PADDLE_ENFORCE_EQ((axis < dim_size) && (axis >= (-1 * dim_size)), true,
|
|
|
|
|
"the axis of topk"
|
|
|
|
|
"must be [-%d, %d), but you set axis is %d",
|
|
|
|
|
dim_size, dim_size, axis);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
(axis < dim_size) && (axis >= (-1 * dim_size)), true,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"the axis of topk must be [-%d, %d), but you set axis is %d",
|
|
|
|
|
dim_size, dim_size, axis));
|
|
|
|
|
|
|
|
|
|
if (axis < 0) axis += dim_size;
|
|
|
|
|
|
|
|
|
@ -47,18 +45,22 @@ class TopkV2Op : public framework::OperatorWithKernel {
|
|
|
|
|
} else {
|
|
|
|
|
k = static_cast<int>(ctx->Attrs().Get<int>("k"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(k >= 1, true,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"the attribute of k in the topk must >= 1 or be a "
|
|
|
|
|
"Tensor, but received %d .",
|
|
|
|
|
k);
|
|
|
|
|
k));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(input_dims.size(), 1,
|
|
|
|
|
"input of topk must have >= 1d shape");
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"input of topk must have >= 1d shape"));
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
input_dims[axis], k,
|
|
|
|
|
"input of topk op must have >= %d columns in axis of %d", k, axis);
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"input of topk op must have >= %d columns in axis of %d", k,
|
|
|
|
|
axis));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::DDim dims = input_dims;
|
|
|
|
|