|
|
|
@ -43,9 +43,9 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
|
|
|
|
|
"the number of Ids and Out should be the same");
|
|
|
|
|
|
|
|
|
|
size_t row_ids_size = 0;
|
|
|
|
|
int row_size = 0;
|
|
|
|
|
int embedding_size = 0;
|
|
|
|
|
int64_t row_ids_size = 0;
|
|
|
|
|
int64_t row_size = 0;
|
|
|
|
|
int64_t embedding_size = 0;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < x_tensors.size(); ++i) {
|
|
|
|
|
const auto *x_tensor = x_tensors[i];
|
|
|
|
@ -69,7 +69,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (size_t i = 0; i < x_tensors.size(); ++i) {
|
|
|
|
|
const auto *row_id = row_ids[i];
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < row_id->numel(); ++j) {
|
|
|
|
|
for (auto j = 0; j < row_id->numel(); ++j) {
|
|
|
|
|
int64_t key = row_id->data<int64_t>()[j];
|
|
|
|
|
std::tuple<int64_t, int64_t> val = std::make_tuple(i, j);
|
|
|
|
|
selected_rows_idx_map.insert(std::make_pair(key, val));
|
|
|
|
@ -84,13 +84,13 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
out->set_lod(out_ids->lod());
|
|
|
|
|
|
|
|
|
|
int nums = static_cast<int>(out_ids->dims()[0]);
|
|
|
|
|
auto nums = out_ids->dims()[0];
|
|
|
|
|
auto *out_data = out->mutable_data<T>(
|
|
|
|
|
framework::make_ddim({nums, embedding_size}), place);
|
|
|
|
|
for (int j = 0; j < nums; ++j) {
|
|
|
|
|
int id = out_ids->data<int64_t>()[j];
|
|
|
|
|
auto row_tuple = selected_rows_idx_map[id];
|
|
|
|
|
int64_t row_idx = std::get<1>(row_tuple);
|
|
|
|
|
for (auto j = 0; j < nums; ++j) {
|
|
|
|
|
auto id = out_ids->data<int64_t>()[j];
|
|
|
|
|
auto row_tuple = selected_rows_idx_map.at(id);
|
|
|
|
|
auto row_idx = std::get<1>(row_tuple);
|
|
|
|
|
const auto *x_tensor = x_tensors[std::get<0>(row_tuple)];
|
|
|
|
|
|
|
|
|
|
memcpy(out_data + embedding_size * j,
|
|
|
|
|