|
|
|
@ -37,10 +37,10 @@ class ArgsortOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Attr(axis) %d of ArgsortOp is out of bounds for Input(X) "
|
|
|
|
|
"dimension %d.",
|
|
|
|
|
axis, num_dims);
|
|
|
|
|
PADDLE_ENFORCE(axis >= 0 || axis == -1,
|
|
|
|
|
"Attr(axis) %d of ArgsortOp must be nonnegative or equal to "
|
|
|
|
|
"-1.",
|
|
|
|
|
axis);
|
|
|
|
|
PADDLE_ENFORCE(in_dims.size() + axis >= 0,
|
|
|
|
|
"Attr(axis) %d of ArgsortOp plus the number of Input(X)'s "
|
|
|
|
|
"dimensions %d must be nonnegative.",
|
|
|
|
|
axis, in_dims.size());
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", in_dims);
|
|
|
|
|
ctx->SetOutputDim("Indices", in_dims);
|
|
|
|
@ -53,9 +53,12 @@ class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddInput("X", "(Tensor) The input of Argsort op.");
|
|
|
|
|
AddOutput("Out", "(Tensor) The sorted tensor of Argsort op.");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(Tensor) The sorted tensor of Argsort op, with the same "
|
|
|
|
|
"shape as Input(X).");
|
|
|
|
|
AddOutput("Indices",
|
|
|
|
|
"(Tensor) The indices of a tensor giving the sorted order.");
|
|
|
|
|
"(Tensor) The indices of a tensor giving the sorted order, with "
|
|
|
|
|
"the same shape as Input(X).");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Argsort operator
|
|
|
|
|
|
|
|
|
@ -66,8 +69,9 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
AddAttr<int>("axis",
|
|
|
|
|
"(int, default -1) The axis along which to sort the tensor, "
|
|
|
|
|
"default -1, the last dimension.")
|
|
|
|
|
"(int, default -1) The axis along which to sort the tensor. "
|
|
|
|
|
"When axis < 0, the actual axis will be the |axis|'th "
|
|
|
|
|
"counting backwards. Default -1, the last dimension.")
|
|
|
|
|
.SetDefault(-1);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|