optimize code and comments

wangkuiyi-patch-1
qiaolongfei 7 years ago
parent f031555cfb
commit d6c8d2675c

@ -21,15 +21,17 @@ class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
AddInput("X",
"(LoDTensor) the input tensor with shape{batch_num, N}, N is the "
"size of embedding table")
AddInput(
"X",
"(LoDTensors) multi input tensor with shape{batch_num, N}, N is the "
"size of embedding table")
.AsDuplicable();
AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");
AddComment(R"DOC(
Merge multi LoDTensor's into one according to Ids's shard num.
The values in the input LoDTensor are lookuped from the output of splite_ids_op
The values in the input LoDTensor are looked up from the output of split_ids_op
Example:
Input:
Ids = [1,2,3,4,5,6]

@ -47,10 +47,10 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
int batch_size = 0;
int embedding_size = 0;
for (auto &input : x_tensors) {
if (embedding_size == 0) {
embedding_size = input->dims()[1];
}
if (framework::product(input->dims()) != 0) {
if (embedding_size == 0) {
embedding_size = input->dims()[1];
}
PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
"embedding size of all input should be the same");
batch_size += input->dims()[0];
@ -58,7 +58,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
}
PADDLE_ENFORCE_EQ(
batch_size, ids_dims[0],
"the batch size of ids and embedding value should be the same");
"the batch size of ids and merged embedding value should be the same");
const size_t shard_num = x_tensors.size();
@ -80,7 +80,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
in_indexs[shard_id] += 1;
}
for (int i = 0; i < shard_num; ++i) {
for (size_t i = 0; i < shard_num; ++i) {
PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
"after merge, all data in x_tensor should be used");
}

Loading…
Cancel
Save