@@ -47,10 +47,10 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     int batch_size = 0;
     int embedding_size = 0;
     for (auto &input : x_tensors) {
-      if (embedding_size == 0) {
-        embedding_size = input->dims()[1];
-      }
       if (framework::product(input->dims()) != 0) {
+        if (embedding_size == 0) {
+          embedding_size = input->dims()[1];
+        }
         PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
                           "embedding size of all input should be the same");
         batch_size += input->dims()[0];
@@ -58,7 +58,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     }
     PADDLE_ENFORCE_EQ(
         batch_size, ids_dims[0],
-        "the batch size of ids and embedding value should be the same");
+        "the batch size of ids and merged embedding value should be the same");
 
     const size_t shard_num = x_tensors.size();
 
@@ -80,7 +80,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
       in_indexs[shard_id] += 1;
     }
 
-    for (int i = 0; i < shard_num; ++i) {
+    for (size_t i = 0; i < shard_num; ++i) {
       PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
                         "after merge, all data in x_tensor should be used");
     }
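Below is a minimal standalone sketch, not the Paddle kernel itself, of the shape checks the first two hunks converge on: the embedding size is inferred only from non-empty shard outputs, so an empty shard can neither set nor break the consistency check, and the accumulated batch size is what gets compared against ids_dims[0]. The Shard struct and the MergedBatchSize helper are illustrative names introduced here, not part of the source.

// Illustrative sketch only, not the Paddle kernel: plain C++ stand-ins for
// the tensors, showing why empty shards are skipped before the embedding
// size is inferred and compared.
#include <cassert>
#include <cstdio>
#include <vector>

struct Shard {
  int rows;  // number of ids routed to this shard
  int cols;  // embedding size produced by this shard
};

// Mirrors the validation in the hunks above: skip empty shards, infer the
// embedding size from the first non-empty one, require every other
// non-empty shard to agree, and accumulate the merged batch size.
int MergedBatchSize(const std::vector<Shard> &shards) {
  int batch_size = 0;
  int embedding_size = 0;
  for (const auto &s : shards) {
    if (s.rows * s.cols != 0) {  // analogue of framework::product(dims) != 0
      if (embedding_size == 0) {
        embedding_size = s.cols;
      }
      assert(embedding_size == s.cols &&
             "embedding size of all input should be the same");
      batch_size += s.rows;
    }
  }
  return batch_size;
}

int main() {
  // One empty shard among the inputs must not affect the inferred
  // embedding size or the merged batch size (3 + 0 + 5 = 8).
  std::vector<Shard> shards = {{3, 8}, {0, 0}, {5, 8}};
  std::printf("merged batch size = %d\n", MergedBatchSize(shards));
  return 0;
}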