fix bug in merge_ids (#15503)

* fix mistakes in merge_ids, test=develop
inference-pre-release-gpu
tangwei12 6 years ago committed by GitHub
parent a7ba07d7ef
commit 981fc2bdba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,9 +43,9 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
"the number of Ids and Out should be the same");
size_t row_ids_size = 0;
int row_size = 0;
int embedding_size = 0;
int64_t row_ids_size = 0;
int64_t row_size = 0;
int64_t embedding_size = 0;
for (size_t i = 0; i < x_tensors.size(); ++i) {
const auto *x_tensor = x_tensors[i];
@ -69,7 +69,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < x_tensors.size(); ++i) {
const auto *row_id = row_ids[i];
for (int j = 0; j < row_id->numel(); ++j) {
for (auto j = 0; j < row_id->numel(); ++j) {
int64_t key = row_id->data<int64_t>()[j];
std::tuple<int64_t, int64_t> val = std::make_tuple(i, j);
selected_rows_idx_map.insert(std::make_pair(key, val));
@ -84,13 +84,13 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
out->set_lod(out_ids->lod());
int nums = static_cast<int>(out_ids->dims()[0]);
auto nums = out_ids->dims()[0];
auto *out_data = out->mutable_data<T>(
framework::make_ddim({nums, embedding_size}), place);
for (int j = 0; j < nums; ++j) {
int id = out_ids->data<int64_t>()[j];
auto row_tuple = selected_rows_idx_map[id];
int64_t row_idx = std::get<1>(row_tuple);
for (auto j = 0; j < nums; ++j) {
auto id = out_ids->data<int64_t>()[j];
auto row_tuple = selected_rows_idx_map.at(id);
auto row_idx = std::get<1>(row_tuple);
const auto *x_tensor = x_tensors[std::get<0>(row_tuple)];
memcpy(out_data + embedding_size * j,

Loading…
Cancel
Save