|
|
|
@ -21,7 +21,7 @@ class TopkOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of TopkOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
@ -44,12 +44,25 @@ class TopkOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->ShareLoD("X", "Out");
|
|
|
|
|
ctx->ShareLoD("X", "Indices");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
|
|
|
|
|
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
|
|
|
|
|
ctx.device_context(), layout_, library_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddInput("X", "(Tensor) The input of Topk op");
|
|
|
|
|
AddInput("K",
|
|
|
|
|
"(Tensor) Number of top elements to look for along "
|
|
|
|
|
"the last dimension (along each row for matrices).")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddOutput("Out", "(Tensor) The output tensor of Topk op");
|
|
|
|
|
AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|