|
|
|
@ -19,21 +19,28 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
class ReorderLoDTensorProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class ReorderLoDTensorByRankTableOpProtoMaker
|
|
|
|
|
: public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
ReorderLoDTensorProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
ReorderLoDTensorByRankTableOpProtoMaker(OpProto *proto,
|
|
|
|
|
OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "(LoDTensor) the input lod tensor need to be reordered.");
|
|
|
|
|
AddInput("RankTable",
|
|
|
|
|
"(LoDRankTable) the rank table that input need follow");
|
|
|
|
|
AddOutput("Out", "(LoDTensor) reordered lod tensor");
|
|
|
|
|
AddComment(R"DOC(ReorderLoDTensorLoDRankTable
|
|
|
|
|
AddComment(R"DOC(ReorderLoDTensorByRankTable
|
|
|
|
|
|
|
|
|
|
Reorder the input X by the rank of `RankTable`. If `RankTable` is ordered by
|
|
|
|
|
index [3, 0, 2, 1]. Input X will reorder its sequence, the third sequence of
|
|
|
|
|
X will be the first sequence of Output.
|
|
|
|
|
|
|
|
|
|
NOTE: The RankTable does not need to be calculated by X.
|
|
|
|
|
|
|
|
|
|
For example:
|
|
|
|
|
The X = [Seq0, Seq1, Seq2, Seq3]. The indices of RankTable are [3, 0, 2, 1].
|
|
|
|
|
|
|
|
|
|
The Out = [Seq3, Seq0, Seq2, Seq1] with correct LoD information.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -146,8 +153,9 @@ class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase {
|
|
|
|
|
size_t out_offset = 0;
|
|
|
|
|
out->mutable_lod()->clear();
|
|
|
|
|
for (auto &item : rank_table.items()) {
|
|
|
|
|
out_offset = this->CopyTensorAndLod(dev_ctx, absolute_table[item.index],
|
|
|
|
|
x, out, out_offset);
|
|
|
|
|
PADDLE_ENFORCE_LT(item.index, absolute_table.size());
|
|
|
|
|
out_offset = CopyTensorAndLod(dev_ctx, absolute_table[item.index], x, out,
|
|
|
|
|
out_offset);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -220,6 +228,7 @@ namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(reorder_lod_tensor_by_rank,
|
|
|
|
|
ops::ReorderLoDTensorByRankTableOp,
|
|
|
|
|
ops::ReorderLodTensorByRankGradOpMaker,
|
|
|
|
|
ops::ReorderLoDTensorProtoMaker, ops::IdentityInferShape);
|
|
|
|
|
ops::ReorderLoDTensorByRankTableOpProtoMaker,
|
|
|
|
|
ops::IdentityInferShape);
|
|
|
|
|
REGISTER_OPERATOR(reorder_lod_tensor_by_rank_grad,
|
|
|
|
|
ops::ReorderLoDTensorByRankGradOp, ops::IdentityInferShape);
|
|
|
|
|