|
|
|
@ -68,6 +68,9 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
const auto &ids_rows = ids_selected_rows->rows();
|
|
|
|
|
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
|
|
|
|
|
const size_t shard_num = outs.size();
|
|
|
|
|
for (auto &out : outs) {
|
|
|
|
|
out->mutable_rows()->clear();
|
|
|
|
|
}
|
|
|
|
|
// get rows for outputs
|
|
|
|
|
std::unordered_map<int64_t, size_t> id_to_index;
|
|
|
|
|
for (size_t i = 0; i < ids_rows.size(); ++i) {
|
|
|
|
|