From 7cebec4b7e876c27af4de5d37101b53260c0bd49 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 10 Jun 2018 22:11:53 +0800 Subject: [PATCH 1/8] init merge_ids_op --- paddle/fluid/operators/merge_ids_op.cc | 97 ++++++++++++++++++++++++++ paddle/fluid/operators/merge_ids_op.h | 83 ++++++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 paddle/fluid/operators/merge_ids_op.cc create mode 100644 paddle/fluid/operators/merge_ids_op.h diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/merge_ids_op.cc new file mode 100644 index 0000000000..939561509c --- /dev/null +++ b/paddle/fluid/operators/merge_ids_op.cc @@ -0,0 +1,97 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/merge_ids_op.h" + +namespace paddle { +namespace operators { + +class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); + AddInput("X", + "(LoDTensor) the input tensor with shape{batch_num, N}, N is the " + "size of embedding table") + .AsDuplicable(); + AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors."); + + AddComment(R"DOC( +Merge multi LoDTensor's into one according to Ids's shard num. 
+The values in the input LoDTensor are lookuped from the output of splite_ids_op
+Example:
+  Input:
+    Ids = [1,2,3,4,5,6]
+    X0 = [[0.1 0.2] # 3
+          [0.2 0.3]] # 6
+    X1 = [[0.3 0.4] # 1
+          [0.4 0.5]] # 4
+    X2 = [[0.5 0.6] # 2
+          [0.6 0.7]] # 5
+
+  Output:
+    Out = [[0.3 0.4] # 1
+           [0.5 0.6] # 2
+           [0.1 0.2] # 3
+           [0.4 0.5] # 4
+           [0.6 0.7] # 5
+           [0.2 0.3]] # 6
+)DOC");
+  }
+};
+
+class MergeIdsOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must have input Ids.");
+    PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must have input X.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must have output Out.");
+
+    auto ids_var_type = ctx->GetInputsVarType("Ids").front();
+    auto ids_dims = ctx->GetInputDim("Ids");
+    if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
+      PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
+      PADDLE_ENFORCE_EQ(ids_dims[1], 1);
+    }
+    auto x_var_type = ctx->GetInputsVarType("X");
+    for (auto &var_type : x_var_type) {
+      PADDLE_ENFORCE_EQ(var_type, framework::proto::VarType::LOD_TENSOR,
+                        "input X only support lod tensors");
+    }
+    ctx->ShareLoD("Ids", "Out");
+  }
+};
+
+class MergeIdsOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    auto *input_var = block->Var(op_desc.Input("Ids")[0]);
+    for (auto &out_var : op_desc.Output("Out")) {
+      block->Var(out_var)->SetType(input_var->GetType());
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker,
+                  ops::MergeIdsOpInferVarType);
+REGISTER_OP_CPU_KERNEL(
+    merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>,
+    ops::MergeIdsOpKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/fluid/operators/merge_ids_op.h b/paddle/fluid/operators/merge_ids_op.h
new file mode 100644
index 0000000000..fd5b542ceb
--- /dev/null
+++ b/paddle/fluid/operators/merge_ids_op.h
@@ -0,0 +1,83 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class MergeIdsOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto place = ctx.GetPlace();
+    if (!platform::is_cpu_place(place)) {
+      PADDLE_THROW("MergeIds does not support GPU kernel");
+    }
+
+    const auto *ids_var = ctx.InputVar("Ids");
+    PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
+                   "only support to merge Ids of LoDTensor");
+
+    const auto &ids_tensor = ids_var->Get<framework::LoDTensor>();
+    const auto &ids_dims = ids_tensor.dims();
+    const T *ids = ids_tensor.data<T>();
+
+    auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
+
+    auto *out = ctx.Output<framework::LoDTensor>("Out");
+
+    int batch_size = 0;
+    int embedding_size = 0;
+    for (auto &input : x_tensors) {
+      if (embedding_size == 0) {
+        embedding_size = input->dims()[1];
+      }
+      PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
+                        "embedding size of all input should be the same");
+      batch_size += input->dims()[0];
+    }
+    PADDLE_ENFORCE_EQ(
+        batch_size, ids_dims[0],
+        "the batch size of ids and embedding value should be the same");
+
+    const size_t shard_num = x_tensors.size();
+
+    if (shard_num == 1) {
+      VLOG(3) << "only one shard, we can copy the data directly";
+      TensorCopy(ids_tensor, place, out);
+    } else {
+      std::vector<int> in_indexs(shard_num, 0);
+      auto *out_data = out->mutable_data<T>(ids_dims, place);
+      // copy data from ins[shard_num] to out.
+      for (int i = 0; i < ids_dims[0]; ++i) {
+        T id = ids[i];
+        size_t shard_id = static_cast<size_t>(id) % shard_num;
+        int index = in_indexs[shard_id];
+        memcpy(out_data + embedding_size * i,
+               x_tensors[shard_id]->data<T>() + index * embedding_size,
+               sizeof(T) * embedding_size);
+        in_indexs[shard_id] += 1;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
From 509cb0bc76ea3b22423e293f608d4956a63deda7 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Sun, 10 Jun 2018 23:31:41 +0800
Subject: [PATCH 2/8] add unit test, pass the unit test

---
 paddle/fluid/operators/merge_ids_op.cc        | 12 +++++-
 paddle/fluid/operators/merge_ids_op.h         | 23 +++++++----
 .../tests/unittests/test_merge_ids_op.py      | 38 +++++++++++++++++++
 3 files changed, 64 insertions(+), 9 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_merge_ids_op.py

diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/merge_ids_op.cc
index 939561509c..bae649adec 100644
--- a/paddle/fluid/operators/merge_ids_op.cc
+++ b/paddle/fluid/operators/merge_ids_op.cc
@@ -73,6 +73,15 @@ class MergeIdsOp : public framework::OperatorWithKernel {
     }
     ctx->ShareLoD("Ids", "Out");
   }
+
+ private:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(
+            ctx.MultiInput<framework::LoDTensor>("X").front()->type()),
+        ctx.GetPlace());
+  }
 };
 
 class MergeIdsOpInferVarType : public framework::VarTypeInference {
@@ -93,5 +102,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker,
                   ops::MergeIdsOpInferVarType);
 REGISTER_OP_CPU_KERNEL(
-    merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>,
-    ops::MergeIdsOpKernel<paddle::platform::CPUPlace, double>);
+    merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/fluid/operators/merge_ids_op.h b/paddle/fluid/operators/merge_ids_op.h
index fd5b542ceb..065368f8dd 100644
--- a/paddle/fluid/operators/merge_ids_op.h
+++ b/paddle/fluid/operators/merge_ids_op.h
@@ -30,6 +30,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     if (!platform::is_cpu_place(place)) {
       PADDLE_THROW("MergeIds does not support GPU kernel");
     }
+    VLOG(3) << "run in MergeIdsOpKernel";
 
     const auto *ids_var = ctx.InputVar("Ids");
     PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
@@ -37,7 +38,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
 
     const auto &ids_tensor = ids_var->Get<framework::LoDTensor>();
     const auto &ids_dims = ids_tensor.dims();
-    const T *ids = ids_tensor.data<T>();
+    const int64_t *ids = ids_tensor.data<int64_t>();
 
     auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
 
     auto *out = ctx.Output<framework::LoDTensor>("Out");
@@ -49,9 +50,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
       if (embedding_size == 0) {
         embedding_size = input->dims()[1];
       }
-      PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
-                        "embedding size of all input should be the same");
-      batch_size += input->dims()[0];
+      if (framework::product(input->dims()) != 0) {
+        PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
+                          "embedding size of all input should be the same");
+        batch_size += input->dims()[0];
+      }
     }
     PADDLE_ENFORCE_EQ(
         batch_size, ids_dims[0],
@@ -61,13 +64,14 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     if (shard_num == 1) {
       VLOG(3) << "only one shard, we can copy the data directly";
-      TensorCopy(ids_tensor, place, out);
+      TensorCopy(*x_tensors[0], place, out);
     } else {
       std::vector<int> in_indexs(shard_num, 0);
-      auto *out_data = out->mutable_data<T>(ids_dims, place);
+      auto *out_data = out->mutable_data<T>(
+          framework::make_ddim({batch_size, embedding_size}), place);
       // copy data from ins[shard_num] to out.
       for (int i = 0; i < ids_dims[0]; ++i) {
-        T id = ids[i];
+        int64_t id = ids[i];
         size_t shard_id = static_cast<size_t>(id) % shard_num;
         int index = in_indexs[shard_id];
         memcpy(out_data + embedding_size * i,
@@ -75,6 +79,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
                sizeof(T) * embedding_size);
         in_indexs[shard_id] += 1;
       }
+
+      for (int i = 0; i < shard_num; ++i) {
+        PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
+                          "after merge, all data in x_tensor should be used");
+      }
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_merge_ids_op.py b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py
new file mode 100644
index 0000000000..f209bdf30f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import unittest +import numpy as np +from op_test import OpTest + + +class TestMergeIdsOp(OpTest): + def setUp(self): + self.op_type = "merge_ids" + ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') + x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32') + x1 = np.array([]).astype('float32') + x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6], + [0.5, 0.6]]).astype('float32') + out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3], + [0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32') + self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() From a941786393e5b840a98215d81ba09c6cf0995ca1 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 11 Jun 2018 09:40:19 +0800 Subject: [PATCH 3/8] replace concat_op with merge_ids_op --- .../fluid/transpiler/distribute_transpiler.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 27992df462..ed4158bc4c 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -618,7 +618,7 @@ class DistributeTranspiler: if op.type == LOOKUP_TABLE_TYPE: continue_search_lookup_table_op = True - op_index = list(all_ops).index(op) + lookup_table_op_index = list(all_ops).index(op) ids_name = op.input("Ids") out_name = op.output("Out") @@ -637,7 +637,7 @@ class DistributeTranspiler: # insert split_ids_op program.global_block().insert_op( - index=op_index, + index=lookup_table_op_index, type="split_ids", inputs={ 'Ids': [ @@ -649,7 +649,7 @@ class DistributeTranspiler: # insert prefetch_op program.global_block().insert_op( - index=op_index + 1, + index=lookup_table_op_index + 1, type="prefetch", inputs={'X': self.prefetch_input_vars}, outputs={"Out": self.prefetch_output_vars}, @@ -660,16 +660,21 @@ class DistributeTranspiler: # insert concat_op program.global_block().insert_op( - index=op_index + 2, - type="concat", - inputs={'X': self.prefetch_output_vars}, + index=lookup_table_op_index + 2, + type="merge_ids", + inputs={ + 'Ids': [ + program.global_block().vars[varname] + for varname in ids_name + ], + 'X': self.prefetch_output_vars + }, outputs={ "Out": [ program.global_block().vars[varname] for varname in out_name ] - }, - attrs={"axis": 0}) + }) # delete lookup_table_op delete_ops(program.global_block(), [op]) From 0485405b3d9d8dcf45139c63c370ca9124d26744 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 11 Jun 2018 11:04:02 +0800 Subject: [PATCH 4/8] add more debug string --- paddle/fluid/framework/executor.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 3d68c5fb87..15af9c4090 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -317,8 +317,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } for (auto& op : ctx->ops_) { - VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); + VLOG(4) << place_ << " " << op->DebugStringEx(local_scope); op->Run(*local_scope, place_); + // NOTE! 
Please do not delete this line, it's useful because the debug
+    // string before and after op.run are different; after running, the output
+    // will have the right shape, which is useful for debugging.
+    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
 
     if (FLAGS_benchmark) {
       VLOG(2) << "Memory used after operator " + op->Type() + " running: "
From d6c8d2675cd07ae679f922fef83a0a089d04c2ee Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Tue, 12 Jun 2018 23:26:58 +0800
Subject: [PATCH 5/8] optimize code and comment

---
 paddle/fluid/operators/merge_ids_op.cc | 10 ++++++----
 paddle/fluid/operators/merge_ids_op.h  | 10 +++++-----
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/merge_ids_op.cc
index bae649adec..f3940231d7 100644
--- a/paddle/fluid/operators/merge_ids_op.cc
+++ b/paddle/fluid/operators/merge_ids_op.cc
@@ -21,15 +21,17 @@ class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
-    AddInput("X",
-             "(LoDTensor) the input tensor with shape{batch_num, N}, N is the "
-             "size of embedding table")
+    AddInput(
+        "X",
+        "(LoDTensors) multi input tensor with shape{batch_num, N}, N is the "
+        "size of embedding table")
         .AsDuplicable();
     AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");
 
     AddComment(R"DOC(
 Merge multi LoDTensor's into one according to Ids's shard num.
-The values in the input LoDTensor are lookuped from the output of splite_ids_op
+The values in the input LoDTensor are lookuped from the output of split_ids_op
+
 Example:
   Input:
     Ids = [1,2,3,4,5,6]
diff --git a/paddle/fluid/operators/merge_ids_op.h b/paddle/fluid/operators/merge_ids_op.h
index 065368f8dd..83712a8519 100644
--- a/paddle/fluid/operators/merge_ids_op.h
+++ b/paddle/fluid/operators/merge_ids_op.h
@@ -47,10 +47,10 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     int batch_size = 0;
     int embedding_size = 0;
     for (auto &input : x_tensors) {
-      if (embedding_size == 0) {
-        embedding_size = input->dims()[1];
-      }
       if (framework::product(input->dims()) != 0) {
+        if (embedding_size == 0) {
+          embedding_size = input->dims()[1];
+        }
         PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
                           "embedding size of all input should be the same");
         batch_size += input->dims()[0];
@@ -58,7 +58,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     }
     PADDLE_ENFORCE_EQ(
         batch_size, ids_dims[0],
-        "the batch size of ids and embedding value should be the same");
+        "the batch size of ids and merged embedding value should be the same");
 
     const size_t shard_num = x_tensors.size();
 
@@ -80,7 +80,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
         in_indexs[shard_id] += 1;
       }
 
-      for (int i = 0; i < shard_num; ++i) {
+      for (size_t i = 0; i < shard_num; ++i) {
         PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
                           "after merge, all data in x_tensor should be used");
       }
From 7ebef493d58eaf0cbf74bf41e683c3048222b623 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Wed, 13 Jun 2018 14:30:03 +0800
Subject: [PATCH 6/8] add row_size for selected rows in DebugStringEx

---
 paddle/fluid/framework/operator.cc | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index c633a2f847..d22ac66c5c 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -69,6 +69,19 @@ static DDim GetDims(const Scope& scope, const std::string& name,
   }
 }
 
+static int GetRowSize(const Scope& scope, const std::string& name) {
+  Variable* var = scope.FindVar(name);
+  if (var == nullptr) {
+    return -1;
+  }
+
+  if (var->IsType<SelectedRows>()) {
+    return var->Get<SelectedRows>().rows().size();
+  }
+
+  return -1;
+}
+
 static LoD GetLoD(const Scope& scope, const std::string& name) {
   Variable* var = scope.FindVar(name);
   auto default_lod = LoD({{}});
@@ -153,6 +166,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
     for (size_t i = 0; i < input.second.size(); ++i) {
       ss << input.second[i];
       if (scope) {
+        int row_size = GetRowSize(*scope, input.second[i]);
+        if (row_size >= 0) {
+          ss << "[row_size=" << row_size << "]";
+        }
         ss << "[" << GetDims(*scope, input.second[i], true) << "]";
         ss << "(" << GetLoD(*scope, input.second[i]) << ")";
       }
@@ -173,6 +190,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
     for (size_t i = 0; i < output.second.size(); ++i) {
       ss << output.second[i];
       if (scope) {
+        int row_size = GetRowSize(*scope, output.second[i]);
+        if (row_size >= 0) {
+          ss << "[row_size=" << row_size << "]";
+        }
         ss << "[" << GetDims(*scope, output.second[i], true) << "]";
         ss << "(" << GetLoD(*scope, output.second[i]) << ")";
       }
From 2e48ab623e788081b58ae91430e4355fe4c578b9 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Wed, 13 Jun 2018 15:17:45 +0800
Subject: [PATCH 7/8] add more detailed comment

---
 paddle/fluid/operators/merge_ids_op.cc | 55 ++++++++++++++++++--------
 1 file changed, 38 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/merge_ids_op.cc
index f3940231d7..59cd734367 100644
--- a/paddle/fluid/operators/merge_ids_op.cc
+++ b/paddle/fluid/operators/merge_ids_op.cc
@@ -30,25 +30,46 @@ class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
 
     AddComment(R"DOC(
 Merge multi LoDTensor's into one according to Ids's shard num.
-The values in the input LoDTensor are lookuped from the output of split_ids_op
+
+
+split_ids_op -> prefetch_op -> merge_ids_op
+
+
+merge_ids_op should be used after split_ids_op and prefetch_op: split_ids_op
+will split the input Ids into multiple tensors according to each Id's shard
+number, and prefetch_op will send them to the parameter server to prefetch the
+embedding values back. During the split the original order of the ids is lost,
+so merge_ids_op uses the original Ids to restore the order of the fetched
+embedding values and also passes the LoD information to the merged output.
+
+Example:
+
+  Ids = [1,2,3,4,5,6] # 3 shards
+
+split_ids_op ->
+
+  Id0 = [3, 6]
+  Id1 = [1, 4]
+  Id2 = [2, 5]
+
+prefetch_op ->
+
+  X0 = [[0.3 0.3]  # 3
+        [0.6 0.6]] # 6
+  X1 = [[0.1 0.1]  # 1
+        [0.4 0.4]] # 4
+  X2 = [[0.2 0.2]  # 2
+        [0.5 0.5]] # 5
+
+merge_ids_op ->
+
+  Out = [[0.1 0.1]  # 1
+         [0.2 0.2]  # 2
+         [0.3 0.3]  # 3
+         [0.4 0.4]  # 4
+         [0.5 0.5]  # 5
+         [0.6 0.6]] # 6
 )DOC");
   }
 };
From e6f54d5aa2c637c5719add11abda07bbe82aa6c1 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Wed, 13 Jun 2018 15:20:08 +0800
Subject: [PATCH 8/8] update comment

---
 paddle/fluid/operators/merge_ids_op.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/merge_ids_op.cc
index 59cd734367..c6ec4ab047 100644
--- a/paddle/fluid/operators/merge_ids_op.cc
+++ b/paddle/fluid/operators/merge_ids_op.cc
@@ -49,9 +49,9 @@ Example:
 
 split_ids_op ->
 
-  Id0 = [3, 6]
-  Id1 = [1, 4]
-  Id2 = [2, 5]
+  Id0 = [3, 6] # id % 3 == 0
+  Id1 = [1, 4] # id % 3 == 1
+  Id2 = [2, 5] # id % 3 == 2
 
 prefetch_op ->