|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <unordered_map>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
@ -67,10 +68,15 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
|
|
|
|
const auto &ids_rows = ids_selected_rows->rows();
|
|
|
|
const auto &ids_rows = ids_selected_rows->rows();
|
|
|
|
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
|
|
|
|
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
|
|
|
|
const size_t shard_num = outs.size();
|
|
|
|
const size_t shard_num = outs.size();
|
|
|
|
|
|
|
|
for (auto &out : outs) {
|
|
|
|
|
|
|
|
out->mutable_rows()->clear();
|
|
|
|
|
|
|
|
}
|
|
|
|
// get rows for outputs
|
|
|
|
// get rows for outputs
|
|
|
|
for (auto &id : ids_rows) {
|
|
|
|
std::unordered_map<int64_t, size_t> id_to_index;
|
|
|
|
size_t shard_id = static_cast<size_t>(id) % shard_num;
|
|
|
|
for (size_t i = 0; i < ids_rows.size(); ++i) {
|
|
|
|
outs[shard_id]->mutable_rows()->push_back(id);
|
|
|
|
id_to_index[ids_rows[i]] = i;
|
|
|
|
|
|
|
|
size_t shard_id = static_cast<size_t>(ids_rows[i]) % shard_num;
|
|
|
|
|
|
|
|
outs[shard_id]->mutable_rows()->push_back(ids_rows[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int64_t row_width = ids_dims[1];
|
|
|
|
int64_t row_width = ids_dims[1];
|
|
|
@ -80,7 +86,8 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
|
|
|
|
{static_cast<int64_t>(out->rows().size()), row_width});
|
|
|
|
{static_cast<int64_t>(out->rows().size()), row_width});
|
|
|
|
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
|
|
|
|
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
|
|
|
|
for (int64_t i = 0; i < ddim[0]; ++i) {
|
|
|
|
for (int64_t i = 0; i < ddim[0]; ++i) {
|
|
|
|
memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
|
|
|
|
memcpy(output + i * row_width,
|
|
|
|
|
|
|
|
ids + id_to_index[out->rows()[i]] * row_width,
|
|
|
|
row_width * sizeof(T));
|
|
|
|
row_width * sizeof(T));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|