|
|
|
@ -14,7 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
@ -69,7 +69,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
|
|
|
|
|
const size_t shard_num = outs.size();
|
|
|
|
|
// get rows for outputs
|
|
|
|
|
std::map<int64_t, size_t> id_to_index;
|
|
|
|
|
std::unordered_map<int64_t, size_t> id_to_index;
|
|
|
|
|
for (size_t i = 0; i < ids_rows.size(); ++i) {
|
|
|
|
|
id_to_index[ids_rows[i]] = i;
|
|
|
|
|
size_t shard_id = static_cast<size_t>(ids_rows[i]) % shard_num;
|
|
|
|
|