|
|
|
@ -26,6 +26,13 @@ class MergeSelectedRowsOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Input(X) of MergeSelectedRowsOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of MergeSelectedRowsOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X").front(),
|
|
|
|
|
framework::proto::VarType::SELECTED_ROWS,
|
|
|
|
|
"Input X only should be SelectedRows.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(),
|
|
|
|
|
framework::proto::VarType::SELECTED_ROWS,
|
|
|
|
|
"Output Y only should be SelectedRows.");
|
|
|
|
|
|
|
|
|
|
ctx->ShareDim("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -43,7 +50,28 @@ class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
R"DOC(
|
|
|
|
|
MergeSelectedRows Operator.
|
|
|
|
|
|
|
|
|
|
MergeSelectedRows is used to merge the duplicated rows of the input.
|
|
|
|
|
MergeSelectedRows is used to merge the duplicated rows of the input. The
|
|
|
|
|
output's row has no duplicated, and it's order is incremental.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
Input:
|
|
|
|
|
X.rows is [0, 5, 5, 4, 19]
|
|
|
|
|
X.height is 20
|
|
|
|
|
X.value is:
|
|
|
|
|
[[1, 1]
|
|
|
|
|
[2, 2]
|
|
|
|
|
[3, 3]
|
|
|
|
|
[4, 4]
|
|
|
|
|
[6, 6]]
|
|
|
|
|
|
|
|
|
|
Output:
|
|
|
|
|
Out.row is [0, 4, 5, 19]
|
|
|
|
|
Out.height is 20
|
|
|
|
|
Out.value is:
|
|
|
|
|
[[1, 1]
|
|
|
|
|
[4, 4]
|
|
|
|
|
[5, 5]
|
|
|
|
|
[6, 6]]
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|