Refine merge_selected_rows Doc ()

* add doc for MergeSelectedRows
test=develop

* checkout selected_rows
test=develop
revert-15207-remove_op_handle_lock_and_fix_var
chengduo 6 years ago committed by GitHub
parent 3babc80160
commit a015a8a39d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,6 +26,13 @@ class MergeSelectedRowsOp : public framework::OperatorWithKernel {
"Input(X) of MergeSelectedRowsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MergeSelectedRowsOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X").front(),
framework::proto::VarType::SELECTED_ROWS,
"Input X only should be SelectedRows.");
PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(),
framework::proto::VarType::SELECTED_ROWS,
"Output Y only should be SelectedRows.");
ctx->ShareDim("X", /*->*/ "Out");
}
};
@ -43,7 +50,28 @@ class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
R"DOC(
MergeSelectedRows Operator.
MergeSelectedRows is used to merge the duplicated rows of the input.
MergeSelectedRows is used to merge the duplicated rows of the input. The
output's row has no duplicated, and it's order is incremental.
Example:
Input:
X.rows is [0, 5, 5, 4, 19]
X.height is 20
X.value is:
[[1, 1]
[2, 2]
[3, 3]
[4, 4]
[6, 6]]
Output:
Out.row is [0, 4, 5, 19]
Out.height is 20
Out.value is:
[[1, 1]
[4, 4]
[5, 5]
[6, 6]]
)DOC");
}
};

@ -29,7 +29,7 @@ class TestGetTensorFromSelectedRows(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
x_rows = [0, 5, 5, 4, 20]
x_rows = [0, 5, 5, 4, 19]
height = 20
row_numel = 2

@ -29,8 +29,8 @@ class TestMergeSelectedRows(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
x_rows = [0, 5, 5, 4, 20]
out_rows = [0, 4, 5, 20]
x_rows = [0, 5, 5, 4, 19]
out_rows = [0, 4, 5, 19]
height = 20
row_numel = 2

Loading…
Cancel
Save