@@ -30,19 +30,16 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
       PADDLE_THROW("SplitIds do not support GPU kernel");
     }
 
-    const auto* ids_t = ctx.Input<framework::LoDTensor>("Ids");
-    auto& ids_dims = ids_t->dims();
+    auto& ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
+    const T* ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
     auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
-
-    const T* ids = ids_t->data<T>();
-
     const size_t shard_num = outs.size();
 
     std::vector<std::vector<T>> out_ids;
     out_ids.resize(outs.size());
 
     // split id by their shard_num.
-    for (size_t i = 0; i < ids_dims[0]; ++i) {
+    for (int i = 0; i < ids_dims[0]; ++i) {
       T id = ids[i];
       size_t shard_id = static_cast<size_t>(id) % shard_num;
       out_ids[shard_id].push_back(id);
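
For reference, the kernel's sharding rule is simply `shard_id = id % shard_num`, applied to every id in the input tensor. Below is a minimal standalone sketch of that rule, assuming plain `std::vector<int64_t>` in place of `framework::LoDTensor`; the `SplitIds` helper and `main` driver here are illustrative only and are not part of the PaddlePaddle API shown in the diff.

```cpp
// Illustrative sketch of the modulo sharding done by SplitIdsOpKernel,
// using std::vector instead of framework::LoDTensor.
#include <cstdint>
#include <iostream>
#include <vector>

// Send each id to shard (id % shard_num), mirroring the loop in the diff above.
std::vector<std::vector<int64_t>> SplitIds(const std::vector<int64_t>& ids,
                                           size_t shard_num) {
  std::vector<std::vector<int64_t>> out_ids(shard_num);
  for (int64_t id : ids) {
    size_t shard_id = static_cast<size_t>(id) % shard_num;
    out_ids[shard_id].push_back(id);
  }
  return out_ids;
}

int main() {
  // With 3 output shards (e.g. 3 parameter servers), id i goes to shard i % 3.
  auto shards = SplitIds({1, 2, 3, 4, 5, 6}, 3);
  for (size_t s = 0; s < shards.size(); ++s) {
    std::cout << "shard " << s << ":";
    for (int64_t id : shards[s]) std::cout << " " << id;
    std::cout << "\n";  // shard 0: 3 6 | shard 1: 1 4 | shard 2: 2 5
  }
  return 0;
}
```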