From 1509ce663881a202c53bb83e78b974e507e18af6 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 9 Mar 2018 19:25:04 +0800 Subject: [PATCH 1/9] enhancement look_up_table --- paddle/fluid/operators/lookup_table_op.cc | 9 ++++-- paddle/fluid/operators/lookup_table_op.cu | 28 ++++++++++++++++--- paddle/fluid/operators/lookup_table_op.h | 34 +++++++++++++++++++---- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 3acdca17af..461e5bd2d3 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -33,8 +33,13 @@ class LookupTableOp : public framework::OperatorWithKernel { auto table_dims = ctx->GetInputDim("W"); auto ids_dims = ctx->GetInputDim("Ids"); - PADDLE_ENFORCE_EQ(ids_dims.size(), 2); - PADDLE_ENFORCE_EQ(ids_dims[1], 1); + auto ids_var_type = ctx->GetInputsVarType("Ids").front(); + // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. + // Maybe near future we will add concat_rows op. + if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ(ids_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[1], 1); + } ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->ShareLoD("Ids", /*->*/ "Out"); diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 923340f461..125e0f9441 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -74,14 +74,34 @@ class LookupTableCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* table_t = context.Input("W"); - auto* ids_t = context.Input("Ids"); - auto* output_t = context.Output("Out"); int64_t padding_idx = context.Attr("padding_idx"); + auto* ids_var = context.InputVar("Ids"); // int tensor + + int64_t* ids; + int64_t K; + framework::Tensor* output_t; + + // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. + // Maybe near future we will add concat_rows op. + if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + output_t = context.Output("Out"); // float tensor + ids = const_cast(ids_t->data()); + K = ids_t->numel(); + } else if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + output_t = const_cast( + &(context.Output("Out") + ->value())); // float tensor + ids = const_cast(ids_t->rows().CUDAData(context.GetPlace())); + K = ids_t->rows().size(); + output_t->Resize({K, table_t->dims()[1]}); + } else { + PADDLE_THROW("Unsupported Variable Type of Ids"); + } size_t N = table_t->dims()[0]; size_t D = table_t->dims()[1]; - size_t K = ids_t->numel(); - auto* ids = ids_t->data(); auto* table = table_t->data(); auto* output = output_t->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index d88b034e91..b2439c6837 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -22,6 +22,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; @@ -29,25 +30,46 @@ template class LookupTableKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* table_t = context.Input("W"); // float tensor - auto* ids_t = context.Input("Ids"); // int tensor - auto* output_t = context.Output("Out"); // float tensor + auto* table_t = context.Input("W"); // float tensor + auto* ids_var = context.InputVar("Ids"); // int tensor + + int64_t* ids; + int64_t ids_numel; + Tensor* output_t; + + // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. + // Maybe near future we will add concat_rows op. + if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + output_t = context.Output("Out"); + ids = const_cast(ids_t->data()); + ids_numel = ids_t->numel(); + } else if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + output_t = + const_cast(&(context.Output("Out")->value())); + ids = const_cast(ids_t->rows().data()); + ids_numel = ids_t->rows().size(); + output_t->Resize({ids_numel, table_t->dims()[1]}); + } else { + PADDLE_THROW("Unsupported Variable Type of Ids"); + } + int64_t padding_idx = context.Attr("padding_idx"); int N = table_t->dims()[0]; int D = table_t->dims()[1]; - auto* ids = ids_t->data(); auto* table = table_t->data(); auto* output = output_t->mutable_data(context.GetPlace()); if (padding_idx == -1) { - for (int64_t i = 0; i < ids_t->numel(); ++i) { + for (int64_t i = 0; i < ids_numel; ++i) { PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_GE(ids[i], 0); memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); } } else { - for (int64_t i = 0; i < ids_t->numel(); ++i) { + for (int64_t i = 0; i < ids_numel; ++i) { if (ids[i] == padding_idx) { memset(output + i * D, 0, D * sizeof(T)); } else { From f1c3ecb2b2859bbfaac7fd1383f03ff9d5c93207 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 10 Mar 2018 01:28:32 +0800 Subject: [PATCH 2/9] add concat rows --- paddle/fluid/operators/lookup_table_op.cc | 48 +++++++++++- paddle/fluid/operators/lookup_table_op.cu | 20 ++--- paddle/fluid/operators/lookup_table_op.h | 11 +-- .../tests/unittests/test_concat_rows_op.py | 76 +++++++++++++++++++ 4 files changed, 136 insertions(+), 19 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_concat_rows_op.py diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 461e5bd2d3..f32b8896d4 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -34,8 +34,9 @@ class LookupTableOp : public framework::OperatorWithKernel { auto ids_dims = ctx->GetInputDim("Ids"); auto ids_var_type = ctx->GetInputsVarType("Ids").front(); - // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. - // Maybe near future we will add concat_rows op. + // lookup_table and concat_rows use the same InferShape, for lookup_table, + // ids_var_type should be LoDTensor, for concat_rows, it should be + // SelectedRows. if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { PADDLE_ENFORCE_EQ(ids_dims.size(), 2); PADDLE_ENFORCE_EQ(ids_dims[1], 1); @@ -90,6 +91,44 @@ or not. And the output only shares the LoD information with input Ids. } }; +class ConcatRowsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ConcatRowsOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("W", + "(Tensor) The input tensor of concat_rows operator. " + "The rank of this tensor is 2."); + AddInput( + "Ids", + "(SelectedRows) The rows of Ids contains the index to be looked up " + "in W."); + AddOutput("Out", + "(SelectedRows or Tensor) The result of concatenating, which " + "have the same type as W."); + AddAttr("is_sparse", + "(boolean, default true) This attribution is invalid, it's " + "only used by `Lookup Table Operator`.") + .SetDefault(true); + AddAttr("padding_idx", + "(int64, default -1) " + "If the value is -1, it makes no effect to lookup. " + "Otherwise the given value indicates padding the output " + "with zeros whenever lookup encounters it in Ids.") + .SetDefault(-1); + + AddComment(R"DOC( +ConcatRows Operator. + +This operator is used to perform lookups on the W(dense tensor) according to +rows contained by Idx(sparse tensor), then concatenates them into a sparse +tensor or dense tensor. + +The type of Ids(Input) is SelectedRows. + +)DOC"); + } +}; + class LookupTableOpGradDescMaker : public framework::DefaultGradOpDescMaker { using ::paddle::framework::DefaultGradOpDescMaker< @@ -150,3 +189,8 @@ REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel, ops::LookupTableKernel); REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel, ops::LookupTableGradKernel); + +// concat_rows is used by regularization and it doesn't have gradient operation. +REGISTER_OPERATOR(concat_rows, ops::LookupTableOp, ops::ConcatRowsOpMaker); +REGISTER_OP_CPU_KERNEL(concat_rows, ops::LookupTableKernel, + ops::LookupTableKernel); diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 125e0f9441..b880d86cf6 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -79,20 +79,17 @@ class LookupTableCUDAKernel : public framework::OpKernel { int64_t* ids; int64_t K; - framework::Tensor* output_t; + auto* output_t = context.Output("Out"); // float tensor; - // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. - // Maybe near future we will add concat_rows op. - if (ids_var->IsType()) { + // lookup_table and concat_rows use the same kernel, for lookup_table, + // ids_var_type should be LoDTensor, for concat_rows, ids_var_type and + // out_var_type should be SelectedRows. + if (ids_var->IsType()) { auto* ids_t = context.Input("Ids"); - output_t = context.Output("Out"); // float tensor ids = const_cast(ids_t->data()); K = ids_t->numel(); - } else if (ids_var->IsType()) { - auto* ids_t = context.Input("Ids"); - output_t = const_cast( - &(context.Output("Out") - ->value())); // float tensor + } else if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); ids = const_cast(ids_t->rows().CUDAData(context.GetPlace())); K = ids_t->rows().size(); output_t->Resize({K, table_t->dims()[1]}); @@ -194,3 +191,6 @@ REGISTER_OP_CUDA_KERNEL(lookup_table, ops::LookupTableCUDAKernel, REGISTER_OP_CUDA_KERNEL(lookup_table_grad, ops::LookupTableGradCUDAKernel, ops::LookupTableGradCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(concat_rows, ops::LookupTableCUDAKernel, + ops::LookupTableCUDAKernel); diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index b2439c6837..32a0085e06 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -35,19 +35,16 @@ class LookupTableKernel : public framework::OpKernel { int64_t* ids; int64_t ids_numel; - Tensor* output_t; - - // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. - // Maybe near future we will add concat_rows op. + auto* output_t = context.Output("Out"); + // lookup_table and concat_rows use the same kernel, for lookup_table, + // ids_var_type should be LoDTensor, for concat_rows, ids_var_type and + // out_var_type should be SelectedRows. if (ids_var->IsType()) { auto* ids_t = context.Input("Ids"); - output_t = context.Output("Out"); ids = const_cast(ids_t->data()); ids_numel = ids_t->numel(); } else if (ids_var->IsType()) { auto* ids_t = context.Input("Ids"); - output_t = - const_cast(&(context.Output("Out")->value())); ids = const_cast(ids_t->rows().data()); ids_numel = ids_t->rows().size(); output_t->Resize({ids_numel, table_t->dims()[1]}); diff --git a/python/paddle/fluid/tests/unittests/test_concat_rows_op.py b/python/paddle/fluid/tests/unittests/test_concat_rows_op.py new file mode 100644 index 0000000000..6dd25c2e02 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_concat_rows_op.py @@ -0,0 +1,76 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator +from op_test import OpTest + + +class TestConcatRowsOp(OpTest): + def check_with_place(self, place): + scope = core.Scope() + + # create and initialize Grad Variable + height = 10 + rows = [0, 4, 4, 7] + row_numel = 12 + + ids_selected_rows = scope.var('Ids').get_selected_rows() + ids_selected_rows.set_height(height) + ids_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + ids_tensor = ids_selected_rows.get_tensor() + ids_tensor.set(np_array, place) + + # create and initialize W Variable + W = scope.var('W').get_tensor() + W_array = np.full((height, row_numel), 1.0).astype("float32") + for i in range(height): + W_array[i] *= i + W.set(W_array, place) + + Out = scope.var('Out').get_selected_rows() + Out_array = np.full((len(rows), row_numel), -1.0).astype("float32") + Out.set_height(height) + Out.set_rows(rows) + Out_tensor = Out.get_tensor() + Out_tensor.set(Out_array, place) + + # create and run concat_rows_op operator + concat_rows_op = Operator( + "concat_rows", + W='W', + Ids='Ids', + Out='Out', + attrs={'is_sparse': True}) + concat_rows_op.run(scope, place) + + # get and compare result + result_array = np.array(Out_tensor) + + for idx, row in enumerate(rows): + assert (row == result_array[idx]).all() + + def test_concat_rows(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place) + + +if __name__ == "__main__": + unittest.main() From b9397b26680710c924f6e59bd7988eeb4e161fc1 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 13 Mar 2018 11:30:17 +0800 Subject: [PATCH 3/9] remove concat_rows --- paddle/fluid/operators/lookup_table_op.cc | 83 ++++++------------- paddle/fluid/operators/lookup_table_op.cu | 12 +-- paddle/fluid/operators/lookup_table_op.h | 13 +-- .../tests/unittests/test_concat_rows_op.py | 76 ----------------- .../tests/unittests/test_lookup_table_op.py | 49 +++++++++++ 5 files changed, 88 insertions(+), 145 deletions(-) delete mode 100644 python/paddle/fluid/tests/unittests/test_concat_rows_op.py diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index f32b8896d4..753553a686 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -34,9 +34,12 @@ class LookupTableOp : public framework::OperatorWithKernel { auto ids_dims = ctx->GetInputDim("Ids"); auto ids_var_type = ctx->GetInputsVarType("Ids").front(); - // lookup_table and concat_rows use the same InferShape, for lookup_table, - // ids_var_type should be LoDTensor, for concat_rows, it should be - // SelectedRows. + + // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type + // is LoDTensor, this tensor contains the ids to be looked up in W + // and it must be a column vector with rank = 2 while the 2nd dimension + // size must be 1, when Ids's type is SelectedRows, the rows of Ids + // contains the ids to be looked up in W; if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { PADDLE_ENFORCE_EQ(ids_dims.size(), 2); PADDLE_ENFORCE_EQ(ids_dims[1], 1); @@ -60,70 +63,41 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("W", - "An input represents embedding tensors, " + "(Tensor) The input represents embedding tensors, " "which is a learnable parameter."); - AddInput("Ids", - "An input with type int32 or int64 " - "contains the ids to be looked up in W. " - "Ids must be a column vector with rank = 2. " - "The 2nd dimension size must be 1."); - AddOutput("Out", "The lookup results, which have the same type as W."); - AddAttr("is_sparse", - "(boolean, default false) " - "Sparse update") - .SetDefault(false); - AddAttr("padding_idx", - "(int64, default -1) " - "If the value is -1, it makes no effect to lookup. " - "Otherwise the given value indicates padding the output " - "with zeros whenever lookup encounters it in Ids.") - .SetDefault(-1); - AddComment(R"DOC( -Lookup Table Operator. - -This operator is used to perform lookups on the parameter W, -then concatenated into a dense tensor. - -The input Ids can carry the LoD (Level of Details) information, -or not. And the output only shares the LoD information with input Ids. - -)DOC"); - } -}; - -class ConcatRowsOpMaker : public framework::OpProtoAndCheckerMaker { - public: - ConcatRowsOpMaker(OpProto* proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("W", - "(Tensor) The input tensor of concat_rows operator. " - "The rank of this tensor is 2."); AddInput( "Ids", - "(SelectedRows) The rows of Ids contains the index to be looked up " + "(Tensor or SelectedRows) Ids's type can be Tensor or " + "SelectedRows, when Ids's type is Tensor, this tensor contains " + "the ids to be looked up in W and it must be a column vector with " + "rank = 2 while the 2nd dimension size must be 1; when Ids's type is " + "SelectedRows, the rows of Ids contains the ids to be looked up " "in W."); AddOutput("Out", - "(SelectedRows or Tensor) The result of concatenating, which " - "have the same type as W."); + "(Tensor or SelectedRows) The lookup results, which have the " + "same type as W."); AddAttr("is_sparse", - "(boolean, default true) This attribution is invalid, it's " - "only used by `Lookup Table Operator`.") - .SetDefault(true); + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); AddAttr("padding_idx", "(int64, default -1) " "If the value is -1, it makes no effect to lookup. " "Otherwise the given value indicates padding the output " "with zeros whenever lookup encounters it in Ids.") .SetDefault(-1); - AddComment(R"DOC( -ConcatRows Operator. +Lookup Table Operator. -This operator is used to perform lookups on the W(dense tensor) according to -rows contained by Idx(sparse tensor), then concatenates them into a sparse -tensor or dense tensor. +This operator is used to perform lookups on the parameter W, +then concatenated into a dense or sparse tensor. -The type of Ids(Input) is SelectedRows. +The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's +type is SelectedRows, the rows of Ids contains the ids to be looked up in W; +when Ids's type is Tensor, this tensor contains the ids to be looked up in W +and it must be a column vector with rank = 2 while the 2nd dimension size must be 1, +at this time, Ids can carry the LoD (Level of Details) information, or not, and +the output only shares the LoD information with input Ids. )DOC"); } @@ -189,8 +163,3 @@ REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel, ops::LookupTableKernel); REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel, ops::LookupTableGradKernel); - -// concat_rows is used by regularization and it doesn't have gradient operation. -REGISTER_OPERATOR(concat_rows, ops::LookupTableOp, ops::ConcatRowsOpMaker); -REGISTER_OP_CPU_KERNEL(concat_rows, ops::LookupTableKernel, - ops::LookupTableKernel); diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index b880d86cf6..7dce6ae558 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -74,16 +74,16 @@ class LookupTableCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* table_t = context.Input("W"); + auto* output_t = context.Output("Out"); int64_t padding_idx = context.Attr("padding_idx"); - auto* ids_var = context.InputVar("Ids"); // int tensor + auto* ids_var = context.InputVar("Ids"); int64_t* ids; int64_t K; - auto* output_t = context.Output("Out"); // float tensor; - - // lookup_table and concat_rows use the same kernel, for lookup_table, - // ids_var_type should be LoDTensor, for concat_rows, ids_var_type and - // out_var_type should be SelectedRows. + // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type + // is LoDTensor, this tensor contains the ids to be looked up in W; + // when Ids's type is SelectedRows, the rows of Ids contains the + // ids to be looked up in W. if (ids_var->IsType()) { auto* ids_t = context.Input("Ids"); ids = const_cast(ids_t->data()); diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 32a0085e06..8d2839d1b6 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -30,15 +30,16 @@ template class LookupTableKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* table_t = context.Input("W"); // float tensor - auto* ids_var = context.InputVar("Ids"); // int tensor + auto* table_t = context.Input("W"); + auto* output_t = context.Output("Out"); + auto* ids_var = context.InputVar("Ids"); int64_t* ids; int64_t ids_numel; - auto* output_t = context.Output("Out"); - // lookup_table and concat_rows use the same kernel, for lookup_table, - // ids_var_type should be LoDTensor, for concat_rows, ids_var_type and - // out_var_type should be SelectedRows. + // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type + // is LoDTensor, this tensor contains the ids to be looked up in W; + // when Ids's type is SelectedRows, the rows of Ids contains the + // ids to be looked up in W. if (ids_var->IsType()) { auto* ids_t = context.Input("Ids"); ids = const_cast(ids_t->data()); diff --git a/python/paddle/fluid/tests/unittests/test_concat_rows_op.py b/python/paddle/fluid/tests/unittests/test_concat_rows_op.py deleted file mode 100644 index 6dd25c2e02..0000000000 --- a/python/paddle/fluid/tests/unittests/test_concat_rows_op.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import numpy as np -import paddle.fluid.core as core -from paddle.fluid.op import Operator -from op_test import OpTest - - -class TestConcatRowsOp(OpTest): - def check_with_place(self, place): - scope = core.Scope() - - # create and initialize Grad Variable - height = 10 - rows = [0, 4, 4, 7] - row_numel = 12 - - ids_selected_rows = scope.var('Ids').get_selected_rows() - ids_selected_rows.set_height(height) - ids_selected_rows.set_rows(rows) - np_array = np.ones((len(rows), row_numel)).astype("float32") - ids_tensor = ids_selected_rows.get_tensor() - ids_tensor.set(np_array, place) - - # create and initialize W Variable - W = scope.var('W').get_tensor() - W_array = np.full((height, row_numel), 1.0).astype("float32") - for i in range(height): - W_array[i] *= i - W.set(W_array, place) - - Out = scope.var('Out').get_selected_rows() - Out_array = np.full((len(rows), row_numel), -1.0).astype("float32") - Out.set_height(height) - Out.set_rows(rows) - Out_tensor = Out.get_tensor() - Out_tensor.set(Out_array, place) - - # create and run concat_rows_op operator - concat_rows_op = Operator( - "concat_rows", - W='W', - Ids='Ids', - Out='Out', - attrs={'is_sparse': True}) - concat_rows_op.run(scope, place) - - # get and compare result - result_array = np.array(Out_tensor) - - for idx, row in enumerate(rows): - assert (row == result_array[idx]).all() - - def test_concat_rows(self): - places = [core.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(core.CUDAPlace(0)) - for place in places: - self.check_with_place(place) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index 03a5bd24a1..8bd8913faf 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -14,6 +14,8 @@ import unittest import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator from op_test import OpTest @@ -47,5 +49,52 @@ class TestLookupTableOpWithPadding(TestLookupTableOp): pass +# Testing look_up_table when Ids's type is SelectedRows. +class TestLookupTableIdsIsSelectedRows(OpTest): + def check_with_place(self, place): + scope = core.Scope() + + height = 10 + rows = [0, 4, 4, 7] + row_numel = 12 + + ids_selected_rows = scope.var('Ids').get_selected_rows() + ids_selected_rows.set_height(height) + ids_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + ids_tensor = ids_selected_rows.get_tensor() + ids_tensor.set(np_array, place) + + W = scope.var('W').get_tensor() + W_array = np.full((height, row_numel), 1.0).astype("float32") + for i in range(height): + W_array[i] *= i + W.set(W_array, place) + + Out = scope.var('Out').get_selected_rows() + Out_array = np.full((len(rows), row_numel), -1.0).astype("float32") + Out.set_height(height) + Out.set_rows(rows) + Out_tensor = Out.get_tensor() + Out_tensor.set(Out_array, place) + + # create and run concat_rows_op operator + concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out') + concat_rows_op.run(scope, place) + + # get and compare result + result_array = np.array(Out_tensor) + + for idx, row in enumerate(rows): + assert (row == result_array[idx]).all() + + def test_concat_rows(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place) + + if __name__ == "__main__": unittest.main() From 92e2207e183dcc3c66eb94c1f2c45f25f7b2bdc2 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 13 Mar 2018 18:57:44 +0800 Subject: [PATCH 4/9] refine doc --- paddle/fluid/operators/lookup_table_op.cc | 39 +++++++++++++++-------- paddle/fluid/operators/lookup_table_op.cu | 12 +++---- paddle/fluid/operators/lookup_table_op.h | 14 ++++---- 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 461e5bd2d3..50eeadab72 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -34,8 +34,11 @@ class LookupTableOp : public framework::OperatorWithKernel { auto ids_dims = ctx->GetInputDim("Ids"); auto ids_var_type = ctx->GetInputsVarType("Ids").front(); - // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. - // Maybe near future we will add concat_rows op. + // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type + // is LoDTensor, this tensor contains the ids to be looked up in W + // and it must be a column vector with rank = 2 while the 2nd dimension + // size must be 1, when Ids's type is SelectedRows, the rows of Ids + // contains the ids to be looked up in W; if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { PADDLE_ENFORCE_EQ(ids_dims.size(), 2); PADDLE_ENFORCE_EQ(ids_dims[1], 1); @@ -59,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("W", - "An input represents embedding tensors, " + "(Tensor) The input represents embedding tensors, " "which is a learnable parameter."); - AddInput("Ids", - "An input with type int32 or int64 " - "contains the ids to be looked up in W. " - "Ids must be a column vector with rank = 2. " - "The 2nd dimension size must be 1."); - AddOutput("Out", "The lookup results, which have the same type as W."); + AddInput( + "Ids", + "(Tensor or SelectedRows) Ids's type can be Tensor or " + "SelectedRows, when Ids's type is Tensor, this tensor contains " + "the ids to be looked up in W and it must be a column vector with " + "rank = 2 while the 2nd dimension size must be 1; when Ids's type is " + "SelectedRows, the rows of Ids contains the ids to be looked up " + "in W."); + AddOutput("Out", + "(Tensor or SelectedRows) The lookup results, which have the " + "same type as W."); AddAttr("is_sparse", "(boolean, default false) " - "Sparse update") + "Sparse update.") .SetDefault(false); AddAttr("padding_idx", "(int64, default -1) " @@ -81,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { Lookup Table Operator. This operator is used to perform lookups on the parameter W, -then concatenated into a dense tensor. +then concatenated into a dense or sparse tensor. + +The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's +type is SelectedRows, the rows of Ids contains the ids to be looked up in W; +when Ids's type is Tensor, this tensor contains the ids to be looked up in W +and it must be a column vector with rank = 2 while the 2nd dimension size must be 1, +at this time, Ids can carry the LoD (Level of Details) information, or not, and +the output only shares the LoD information with input Ids. -The input Ids can carry the LoD (Level of Details) information, -or not. And the output only shares the LoD information with input Ids. )DOC"); } diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index f314fdbbff..6d81fccd20 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -75,22 +75,22 @@ class LookupTableCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* table_t = context.Input("W"); int64_t padding_idx = context.Attr("padding_idx"); - auto* ids_var = context.InputVar("Ids"); // int tensor + auto* ids_var = context.InputVar("Ids"); + Tensor* output_t = context.Output("Out"); int64_t* ids; int64_t K; - framework::Tensor* output_t; - // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. - // Maybe near future we will add concat_rows op. + // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type + // is LoDTensor, this tensor contains the ids to be looked up in W; + // when Ids's type is SelectedRows, the rows of Ids contains the + // ids to be looked up in W. if (ids_var->IsType()) { auto* ids_t = context.Input("Ids"); - output_t = context.Output("Out"); // float tensor ids = const_cast(ids_t->data()); K = ids_t->numel(); } else if (ids_var->IsType()) { auto* ids_t = context.Input("Ids"); - output_t = context.Output("Out")->mutable_value(); ids = const_cast(ids_t->rows().CUDAData(context.GetPlace())); K = ids_t->rows().size(); output_t->Resize({K, table_t->dims()[1]}); diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 4495b4e9e2..c92ce78eef 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -30,23 +30,23 @@ template class LookupTableKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* table_t = context.Input("W"); // float tensor - auto* ids_var = context.InputVar("Ids"); // int tensor + auto* table_t = context.Input("W"); + auto* ids_var = context.InputVar("Ids"); + Tensor* output_t = context.Output("Out"); int64_t* ids; int64_t ids_numel; - Tensor* output_t; - // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. - // Maybe near future we will add concat_rows op. + // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type + // is LoDTensor, this tensor contains the ids to be looked up in W; + // when Ids's type is SelectedRows, the rows of Ids contains the + // ids to be looked up in W. if (ids_var->IsType()) { auto* ids_t = context.Input("Ids"); - output_t = context.Output("Out"); ids = const_cast(ids_t->data()); ids_numel = ids_t->numel(); } else if (ids_var->IsType()) { auto* ids_t = context.Input("Ids"); - output_t = context.Output("Out")->mutable_value(); ids = const_cast(ids_t->rows().data()); ids_numel = ids_t->rows().size(); output_t->Resize({ids_numel, table_t->dims()[1]}); From 6a1fbf5be955ddcd4a786867df0a49cca39d8005 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 13 Mar 2018 20:28:12 +0800 Subject: [PATCH 5/9] move fluid dist design to fluid folder --- .../design/dist_train}/distributed_architecture.md | 2 +- .../design/dist_train}/multi_cpu.md | 0 .../design/dist_train}/parameter_server.md | 0 .../design/dist_train}/src/compiler.graffle | Bin .../design/dist_train}/src/compiler.png | Bin .../design/dist_train}/src/dist-graph.graffle | Bin .../design/dist_train}/src/dist-graph.png | Bin .../src/distributed_architecture.graffle | Bin .../dist_train}/src/distributed_architecture.png | Bin .../design/dist_train}/src/local-graph.graffle | Bin .../design/dist_train}/src/local-graph.png | Bin .../dist_train}/src/local_architecture.graffle | Bin .../design/dist_train}/src/local_architecture.png | Bin .../design/dist_train}/src/multi-threads.graffle | Bin .../src/multi-threads/multi-threads@3x.png | Bin .../src/multi-threads/single-thread@3x.png | Bin .../design/dist_train}/src/paddle-compile.graffle | Bin .../design/dist_train}/src/paddle-compile.png | Bin .../design/dist_train}/src/remote_executor.graffle | Bin .../design/dist_train}/src/remote_executor.png | Bin .../design/dist_train}/src/sparse_update.graffle | Bin .../design/dist_train}/src/sparse_update.png | Bin 22 files changed, 1 insertion(+), 1 deletion(-) rename doc/{design/fluid_dist => fluid/design/dist_train}/distributed_architecture.md (99%) rename doc/{design/fluid_dist => fluid/design/dist_train}/multi_cpu.md (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/parameter_server.md (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/compiler.graffle (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/compiler.png (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/dist-graph.graffle (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/dist-graph.png (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/distributed_architecture.graffle (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/distributed_architecture.png (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/local-graph.graffle (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/local-graph.png (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/local_architecture.graffle (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/local_architecture.png (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/multi-threads.graffle (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/multi-threads/multi-threads@3x.png (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/multi-threads/single-thread@3x.png (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/paddle-compile.graffle (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/paddle-compile.png (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/remote_executor.graffle (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/remote_executor.png (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/sparse_update.graffle (100%) rename doc/{design/fluid_dist => fluid/design/dist_train}/src/sparse_update.png (100%) diff --git a/doc/design/fluid_dist/distributed_architecture.md b/doc/fluid/design/dist_train/distributed_architecture.md similarity index 99% rename from doc/design/fluid_dist/distributed_architecture.md rename to doc/fluid/design/dist_train/distributed_architecture.md index 9368c5780d..b32b00ec25 100644 --- a/doc/design/fluid_dist/distributed_architecture.md +++ b/doc/fluid/design/dist_train/distributed_architecture.md @@ -1,4 +1,4 @@ -# Design Doc: Distributed Training Architecture +# Design Doc: Fluid Distributed Training Architecture ## Abstract diff --git a/doc/design/fluid_dist/multi_cpu.md b/doc/fluid/design/dist_train/multi_cpu.md similarity index 100% rename from doc/design/fluid_dist/multi_cpu.md rename to doc/fluid/design/dist_train/multi_cpu.md diff --git a/doc/design/fluid_dist/parameter_server.md b/doc/fluid/design/dist_train/parameter_server.md similarity index 100% rename from doc/design/fluid_dist/parameter_server.md rename to doc/fluid/design/dist_train/parameter_server.md diff --git a/doc/design/fluid_dist/src/compiler.graffle b/doc/fluid/design/dist_train/src/compiler.graffle similarity index 100% rename from doc/design/fluid_dist/src/compiler.graffle rename to doc/fluid/design/dist_train/src/compiler.graffle diff --git a/doc/design/fluid_dist/src/compiler.png b/doc/fluid/design/dist_train/src/compiler.png similarity index 100% rename from doc/design/fluid_dist/src/compiler.png rename to doc/fluid/design/dist_train/src/compiler.png diff --git a/doc/design/fluid_dist/src/dist-graph.graffle b/doc/fluid/design/dist_train/src/dist-graph.graffle similarity index 100% rename from doc/design/fluid_dist/src/dist-graph.graffle rename to doc/fluid/design/dist_train/src/dist-graph.graffle diff --git a/doc/design/fluid_dist/src/dist-graph.png b/doc/fluid/design/dist_train/src/dist-graph.png similarity index 100% rename from doc/design/fluid_dist/src/dist-graph.png rename to doc/fluid/design/dist_train/src/dist-graph.png diff --git a/doc/design/fluid_dist/src/distributed_architecture.graffle b/doc/fluid/design/dist_train/src/distributed_architecture.graffle similarity index 100% rename from doc/design/fluid_dist/src/distributed_architecture.graffle rename to doc/fluid/design/dist_train/src/distributed_architecture.graffle diff --git a/doc/design/fluid_dist/src/distributed_architecture.png b/doc/fluid/design/dist_train/src/distributed_architecture.png similarity index 100% rename from doc/design/fluid_dist/src/distributed_architecture.png rename to doc/fluid/design/dist_train/src/distributed_architecture.png diff --git a/doc/design/fluid_dist/src/local-graph.graffle b/doc/fluid/design/dist_train/src/local-graph.graffle similarity index 100% rename from doc/design/fluid_dist/src/local-graph.graffle rename to doc/fluid/design/dist_train/src/local-graph.graffle diff --git a/doc/design/fluid_dist/src/local-graph.png b/doc/fluid/design/dist_train/src/local-graph.png similarity index 100% rename from doc/design/fluid_dist/src/local-graph.png rename to doc/fluid/design/dist_train/src/local-graph.png diff --git a/doc/design/fluid_dist/src/local_architecture.graffle b/doc/fluid/design/dist_train/src/local_architecture.graffle similarity index 100% rename from doc/design/fluid_dist/src/local_architecture.graffle rename to doc/fluid/design/dist_train/src/local_architecture.graffle diff --git a/doc/design/fluid_dist/src/local_architecture.png b/doc/fluid/design/dist_train/src/local_architecture.png similarity index 100% rename from doc/design/fluid_dist/src/local_architecture.png rename to doc/fluid/design/dist_train/src/local_architecture.png diff --git a/doc/design/fluid_dist/src/multi-threads.graffle b/doc/fluid/design/dist_train/src/multi-threads.graffle similarity index 100% rename from doc/design/fluid_dist/src/multi-threads.graffle rename to doc/fluid/design/dist_train/src/multi-threads.graffle diff --git a/doc/design/fluid_dist/src/multi-threads/multi-threads@3x.png b/doc/fluid/design/dist_train/src/multi-threads/multi-threads@3x.png similarity index 100% rename from doc/design/fluid_dist/src/multi-threads/multi-threads@3x.png rename to doc/fluid/design/dist_train/src/multi-threads/multi-threads@3x.png diff --git a/doc/design/fluid_dist/src/multi-threads/single-thread@3x.png b/doc/fluid/design/dist_train/src/multi-threads/single-thread@3x.png similarity index 100% rename from doc/design/fluid_dist/src/multi-threads/single-thread@3x.png rename to doc/fluid/design/dist_train/src/multi-threads/single-thread@3x.png diff --git a/doc/design/fluid_dist/src/paddle-compile.graffle b/doc/fluid/design/dist_train/src/paddle-compile.graffle similarity index 100% rename from doc/design/fluid_dist/src/paddle-compile.graffle rename to doc/fluid/design/dist_train/src/paddle-compile.graffle diff --git a/doc/design/fluid_dist/src/paddle-compile.png b/doc/fluid/design/dist_train/src/paddle-compile.png similarity index 100% rename from doc/design/fluid_dist/src/paddle-compile.png rename to doc/fluid/design/dist_train/src/paddle-compile.png diff --git a/doc/design/fluid_dist/src/remote_executor.graffle b/doc/fluid/design/dist_train/src/remote_executor.graffle similarity index 100% rename from doc/design/fluid_dist/src/remote_executor.graffle rename to doc/fluid/design/dist_train/src/remote_executor.graffle diff --git a/doc/design/fluid_dist/src/remote_executor.png b/doc/fluid/design/dist_train/src/remote_executor.png similarity index 100% rename from doc/design/fluid_dist/src/remote_executor.png rename to doc/fluid/design/dist_train/src/remote_executor.png diff --git a/doc/design/fluid_dist/src/sparse_update.graffle b/doc/fluid/design/dist_train/src/sparse_update.graffle similarity index 100% rename from doc/design/fluid_dist/src/sparse_update.graffle rename to doc/fluid/design/dist_train/src/sparse_update.graffle diff --git a/doc/design/fluid_dist/src/sparse_update.png b/doc/fluid/design/dist_train/src/sparse_update.png similarity index 100% rename from doc/design/fluid_dist/src/sparse_update.png rename to doc/fluid/design/dist_train/src/sparse_update.png From a43eee40f71352867868714a55dd9fa1135e368f Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 13 Mar 2018 23:08:26 +0800 Subject: [PATCH 6/9] follow comments --- .../tests/unittests/test_lookup_table_op.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index 518ef6a1bf..ed920ad388 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -53,18 +53,11 @@ class TestLookupTableIdsIsSelectedRows(OpTest): def check_with_place(self, place): scope = core.Scope() - # create and initialize Grad Variable + # create and initialize Variable height = 10 rows = [0, 4, 4, 7] row_numel = 12 - ids_selected_rows = scope.var('Ids').get_selected_rows() - ids_selected_rows.set_height(height) - ids_selected_rows.set_rows(rows) - np_array = np.ones((len(rows), row_numel)).astype("float32") - ids_tensor = ids_selected_rows.get_tensor() - ids_tensor.set(np_array, place) - # create and initialize W Variable W = scope.var('W').get_tensor() W_array = np.full((height, row_numel), 1.0).astype("float32") @@ -72,20 +65,26 @@ class TestLookupTableIdsIsSelectedRows(OpTest): W_array[i] *= i W.set(W_array, place) + # create and initialize Ids Variable + ids_selected_rows = scope.var('Ids').get_selected_rows() + ids_selected_rows.set_height(len(rows)) + ids_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + ids_tensor = ids_selected_rows.get_tensor() + ids_tensor.set(np_array, place) + + # create Out Variable Out = scope.var('Out').get_selected_rows() - Out_array = np.full((len(rows), row_numel), -1.0).astype("float32") - Out.set_height(height) - Out.set_rows(rows) - Out_tensor = Out.get_tensor() - Out_tensor.set(Out_array, place) - # create and run concat_rows_op operator + # create and run lookup_table operator concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out') concat_rows_op.run(scope, place) - # get and compare result + # get result from Out + Out_tensor = Out.get_tensor() result_array = np.array(Out_tensor) + # all(): return True if all elements of the iterable are true (or if the iterable is empty) for idx, row in enumerate(rows): assert (row == result_array[idx]).all() From 28078969fd2f3c3e4b3108c0f2a1227b429a8ac1 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Tue, 13 Mar 2018 15:38:17 -0700 Subject: [PATCH 7/9] Fix the CPP Data Feeding design document (#9033) --- doc/design/cpp_data_feeding.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/doc/design/cpp_data_feeding.md b/doc/design/cpp_data_feeding.md index 22c2a925eb..2cbb0083e6 100644 --- a/doc/design/cpp_data_feeding.md +++ b/doc/design/cpp_data_feeding.md @@ -1,17 +1,17 @@ # C++ Data Feeding -In training with Paddle V2 API, data feeding wholly dependents on Python code. To get rid of the Python environment and achieve the goal of "wrapping the whole training by a while loop op" in Paddle Fluid, a C++ data feeding mechanism is required. +While using Paddle V2 API for Training, data feeding completely depends on the Python code. To get rid of the Python environment and achieve the goal of "wrapping the whole training by a while loop op" in Paddle Fluid, a C++ data feeding mechanism is required. -In this document we show the fundamental design of C++ data feeding process, which includes the data reading, shuffling and batching. +In this document we show the fundamental design of a C++ data feeding process, which includes data reading, shuffling and batching. ## Reader -A new concept named 'Reader' is introduced. `Reader` is a series of inherited classes which can be hold by our `Variable` and they are used to read or process file data. +In order to handle the above mentioned problem, a new concept called 'Reader' is introduced. `Reader` is a series of inherited classes which can be held by our `Variable` and they are used to read or process file data. ### `ReaderBase` -`ReaderBase` is the abstract base class of all readers. It defines the all readers' interfaces. +`ReaderBase` is the abstract base class for all readers. It defines the interface for all readers. ```cpp class ReaderBase { @@ -20,10 +20,10 @@ class ReaderBase { PADDLE_ENFORCE(!shapes_.empty()); } // Read the next batch of data. (A 'batch' can be only one instance) - // If the next batch doesn't exist, the '*out' will be an empty std::vector. + // If the next batch doesn't exist, '*out' will be an empty std::vector. virtual void ReadNext(std::vector* out) = 0; - // Reinitialize the reader and read the file from the begin. + // Reinitialize the reader and read the file from the beginning. virtual void ReInit() = 0; // Get a certain read in data's shape. @@ -42,36 +42,36 @@ class ReaderBase { ### `FileReader` and `DecoratedReader` -These two classes are derived from the `ReaderBase` and will further be derived by respective specific readers. That is to say, in our design, there are two kinds of readers: file readers and decorated readers. A file reader reads from a file of some specific format, and yield only one instance of data at a time. e.g. RecordIO reader, jpg reader, .... A decorated reader takes another reader(both file reader and decorated reader are OK) as its 'underlying reader'. It gets data from its underlying reader, does some process on them(shuffling, or batching), then yields processed data. The output data of a decorated reader can be a single instance or a batch. `ShuffleReader` and `BatchReader` are both decorated readers. +These two classes are derived from the `ReaderBase` and will further be derived by more specific readers. Thus, in our design, there are two kinds of readers: file readers and decorated readers. A file reader reads from a file of some specific format, and yield only one instance of data at a time. For example, RecordIO reader, jpg reader, .... A decorated reader takes another reader(both file reader and decorated reader are OK) as its 'underlying reader'. It gets data from its underlying reader, does some processing on them(shuffling, or batching), then yields processed data. The output data of a decorated reader can be a single instance or a batch. `ShuffleReader` and `BatchReader` are both decorated readers. -All the readers share exactly the same interfaces defined in `ReaderBase`. So they can be decorated for more than one time: We can **shuffle** a reader's outputs and then **batch** the shuffle outputs. The interface consistency also allows related ops use readers without knowing what they are exactly. +All the readers share exactly the same interface as defined in `ReaderBase`. So they can be decorated for more than one time: We can **shuffle** a reader's outputs and then **batch** the shuffle outputs. The interface consistency also allows related ops use readers without knowing what they are exactly. ### `ReaderHolder` -Different readers belong to different class types. It leads to a problem: How can we drop them into `Variable`s and fetch them out by a unified method? For example, if a Variable holds a `BatchReader`, we can not get it by the following code: +Different readers belong to different class types. This leads to a problem: How can we drop them into `Variable`s and fetch them out by a unified method? For example, if a Variable holds a `BatchReader`, we can not get it by the following code: ```cpp var->Get("batch_reader"); ``` -we have to write: +We would have to write: ```cpp var->Get("batch_reader"); ``` -This requires each time getting a reader from a variable we must know the reader's type exactly. It is nearly impossible. +This requires that in order to get a reader from a variable, every time, we must know the reader's type exactly. This is nearly impossible. -To solve this problem, we introduce `ReaderHolder` as a wrapper. It acts as an empty decorator of `ReaderBase`, which erases reader's type. With `ReaderHolder` we are able to fetch all types of readers by `var->Get("...")` and regard the obtained object as a reader. +To solve this problem, we introduce `ReaderHolder` as a wrapper. It acts as an empty decorator of `ReaderBase`, which hides reader's type. With `ReaderHolder` we are able to fetch all types of readers by `var->Get("...")` and regard the obtained object as a reader. ## Related Operators -To create and invoke readers, some now ops are introduced: +To create and invoke readers, some new ops are introduced: ### `CreateReaderOp` -Each reader has its creating op. File readers' creating ops have no input and yield the created file reader as its output. Decorated readers' creating ops take the underlying readers as inputs and then yield new decorated readers. +Each reader has its creation op. File readers' creation ops have no input and yield the created file reader as its output. Decorated readers' creation ops take the underlying readers as inputs and then yield new decorated readers. ### `ReadOp` From 14fe40aaa6e19009f6f0836826e367f2ae5c1dee Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 14 Mar 2018 10:29:39 +0800 Subject: [PATCH 8/9] Refine/nccl (#9009) * "Refine nccl op" * "refine code " * "refine nccl code" --- paddle/fluid/operators/nccl_op.cc | 92 +++++++++--------- paddle/fluid/operators/nccl_op.cu.cc | 139 +++++++++------------------ 2 files changed, 89 insertions(+), 142 deletions(-) diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl_op.cc index 329656d26d..5e4ed886b1 100644 --- a/paddle/fluid/operators/nccl_op.cc +++ b/paddle/fluid/operators/nccl_op.cc @@ -104,19 +104,38 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel { " Input(Communicator) of AllReduce op input should not be NULL"); PADDLE_ENFORCE(ctx->HasOutput("Out"), " Output(Out) of AllReduce op output should not be NULL"); - - auto x_dims = ctx->GetInputsDim("X"); - std::string reduction = ctx->Attrs().Get("reduction"); PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || reduction == "ncclMin" || reduction == "ncclMax"), "invalid reduction."); + auto x_dims = ctx->GetInputsDim("X"); ctx->SetOutputsDim("Out", x_dims); ctx->ShareLoD("X", /*->*/ "Out"); } }; +// AllReduceOp +class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of AllReduce op"); + AddInput("Communicator", "Communicator for communicating between gpus"); + AddOutput("Out", "The output of AllReduce op"); + AddAttr("reduction", + "(string, default 'ncclSum') " + "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.") + .SetDefault("ncclSum"); + AddComment(R"DOC( +NCCLAllReduce Operator. + +AllReduce the input tensors. + +)DOC"); + } +}; + // ReduceOp class NCCLReduceOp : public framework::OperatorWithKernel { public: @@ -143,50 +162,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel { } }; -// BcastOp -class NCCLBcastOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - " Input(X) of Bcast op input should not be NULL"); - PADDLE_ENFORCE(ctx->HasInput("Communicator"), - " Input(Communicator) of Bcast op input should not be NULL"); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - " Output(Out) of Bcast op output should not be NULL"); - - int root = ctx->Attrs().Get("root"); - PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set."); - - auto x_dims = ctx->GetInputsDim("X"); - ctx->SetOutputsDim("Out", x_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } -}; - -// AllreduceOp -class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { - public: - NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The input of AllReduce op"); - AddInput("Communicator", "Communicator for communicating between gpus"); - AddOutput("Out", "The output of AllReduce op"); - AddAttr("reduction", - "(string, default 'ncclSum') " - "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.") - .SetDefault("ncclSum"); - AddComment(R"DOC( -NCCLAllReduce Operator. - -AllReduce the input tensors. - -)DOC"); - } -}; - // ReduceOp class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -213,6 +188,29 @@ Reduce the tensors. } }; +// BcastOp +class NCCLBcastOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + " Input(X) of Bcast op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasInput("Communicator"), + " Input(Communicator) of Bcast op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + " Output(Out) of Bcast op output should not be NULL"); + + int root = ctx->Attrs().Get("root"); + PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set."); + + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + // BcastOp class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker { public: diff --git a/paddle/fluid/operators/nccl_op.cu.cc b/paddle/fluid/operators/nccl_op.cu.cc index 683a520e99..4d83a70e73 100644 --- a/paddle/fluid/operators/nccl_op.cu.cc +++ b/paddle/fluid/operators/nccl_op.cu.cc @@ -43,13 +43,12 @@ class NCCLAllReduceKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - - auto ins = ctx.MultiInput("X"); - auto outs = ctx.MultiOutput("Out"); - + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + auto* comm = ctx.Input("Communicator"); std::string reduction = ctx.Attr("reduction"); - ncclRedOp_t reduction_op_ = ncclSum; + ncclRedOp_t reduction_op_ = ncclSum; if (reduction == "ncclMin") { reduction_op_ = ncclMin; } else if (reduction == "ncclMax") { @@ -61,30 +60,19 @@ class NCCLAllReduceKernel : public framework::OpKernel { } else { PADDLE_THROW("Invalid reduction. default ncclSum."); } - - auto* comm = ctx.Input("Communicator"); - - auto stream = ctx.cuda_device_context().stream(); - // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(gpu_id); - - for (size_t i = 0; i < ins.size(); ++i) { - VLOG(1) << "gpu : " - << " invoke allreduce. send " << ins[i]->numel() << " recv " - << outs[i]->numel(); - - PADDLE_ENFORCE(platform::dynload::ncclAllReduce( - ins[i]->data(), outs[i]->mutable_data(ctx.GetPlace()), - outs[i]->numel(), NCCLTypeWrapper::type, reduction_op_, - comm->comms().at(idx), stream)); - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - - VLOG(1) << "gpu : " - << " finished allreduce. send " << ins[i]->numel() << " recv " - << outs[i]->numel(); - } + VLOG(3) << "gpu : " + << " invoke allreduce. send " << x->numel() << " recv " + << out->numel(); + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + x->data(), out->mutable_data(ctx.GetPlace()), out->numel(), + NCCLTypeWrapper::type, reduction_op_, comm->comms().at(idx), + ctx.cuda_device_context().stream())); + VLOG(3) << "gpu : " + << " finished allreduce. send " << x->numel() << " recv " + << out->numel(); } }; @@ -94,13 +82,13 @@ class NCCLReduceKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - - auto ins = ctx.MultiInput("X"); // x0, x1, x2 - auto outs = ctx.MultiOutput("Out"); - + auto x = ctx.Input("X"); // x0, x1, x2 + auto out = ctx.Output("Out"); + auto* comm = ctx.Input("Communicator"); + int root = ctx.Attr("root"); std::string reduction = ctx.Attr("reduction"); - ncclRedOp_t reduction_op_ = ncclSum; + ncclRedOp_t reduction_op_ = ncclSum; if (reduction == "ncclMin") { reduction_op_ = ncclMin; } else if (reduction == "ncclMax") { @@ -112,40 +100,21 @@ class NCCLReduceKernel : public framework::OpKernel { } else { PADDLE_THROW("Invalid reduction. default ncclSum."); } - - int root = ctx.Attr("root"); - auto* comm = ctx.Input("Communicator"); - - auto stream = reinterpret_cast( - ctx.device_context()) - .stream(); // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(gpu_id); - - auto ins_names = ctx.Inputs("X"); - std::hash hasher; - for (size_t i = 0; i < ins.size(); ++i) { - if (root == platform::kInvalidGPUId) { - root = hasher(ins_names[i]) % comm->comms().size(); - } - T* recvbuffer = nullptr; - if (root == gpu_id) { - recvbuffer = outs[i]->mutable_data(ctx.GetPlace()); - } - - VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send " - << ins[i]->numel() << " recv " << outs[i]->numel(); - - PADDLE_ENFORCE(platform::dynload::ncclReduce( - ins[i]->data(), recvbuffer, ins[i]->numel(), - NCCLTypeWrapper::type, reduction_op_, root, comm->comms().at(idx), - stream)); - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - - VLOG(1) << "gpu : " << gpu_id << " finished reduce. send " - << ins[i]->numel() << " recv " << outs[i]->numel(); + T* recvbuffer = nullptr; + if (root == gpu_id) { + recvbuffer = out->mutable_data(ctx.GetPlace()); } + VLOG(3) << "gpu : " << gpu_id << " invoke reduce. send " << x->numel() + << " recv " << out->numel(); + PADDLE_ENFORCE(platform::dynload::ncclReduce( + x->data(), recvbuffer, x->numel(), NCCLTypeWrapper::type, + reduction_op_, root, comm->comms().at(idx), + ctx.cuda_device_context().stream())); + VLOG(3) << "gpu : " << gpu_id << " finished reduce. send " << x->numel() + << " recv " << out->numel(); } }; @@ -155,47 +124,27 @@ class NCCLBcastKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - int root = ctx.Attr("root"); - auto* comm = ctx.Input("Communicator"); - - auto stream = reinterpret_cast( - ctx.device_context()) - .stream(); // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(gpu_id); - if (idx == root) { - auto ins = ctx.MultiInput("X"); - for (size_t i = 0; i < ins.size(); ++i) { - VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send " - << ins[i]->numel(); - - VLOG(1) << " before ncclBcast"; - PADDLE_ENFORCE(platform::dynload::ncclBcast( - (void*)ins[i]->data(), ins[i]->numel(), NCCLTypeWrapper::type, - root, comm->comms().at(idx), stream)); - VLOG(1) << " after ncclBcast"; - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - - VLOG(1) << "gpu : " << gpu_id << " finished Bcast."; - } + auto* x = ctx.Input("X"); + VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. send " << x->numel(); + PADDLE_ENFORCE(platform::dynload::ncclBcast( + (void*)x->data(), x->numel(), NCCLTypeWrapper::type, root, + comm->comms().at(idx), ctx.cuda_device_context().stream())); + VLOG(3) << "gpu : " << gpu_id << " finished Bcast."; } else { - auto outs = ctx.MultiOutput("Out"); - for (size_t i = 0; i < outs.size(); ++i) { - VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer " - << framework::product(outs[i]->dims()); - - PADDLE_ENFORCE(platform::dynload::ncclBcast( - outs[i]->mutable_data(ctx.GetPlace()), outs[i]->numel(), - NCCLTypeWrapper::type, root, comm->comms().at(idx), stream)); - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - - VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv " - << outs[i]->numel(); - } + auto* out = ctx.Output("Out"); + VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. recv buffer " + << framework::product(out->dims()); + PADDLE_ENFORCE(platform::dynload::ncclBcast( + out->mutable_data(ctx.GetPlace()), out->numel(), + NCCLTypeWrapper::type, root, comm->comms().at(idx), + ctx.cuda_device_context().stream())); + VLOG(3) << "gpu : " << gpu_id << " finished Bcast. recv " << out->numel(); } } }; From d13ce3587559c5553f05d75789269a0dff49734f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= Date: Wed, 14 Mar 2018 10:38:01 +0800 Subject: [PATCH 9/9] Feature/send recv can now retry (#9027) --- paddle/fluid/operators/detail/grpc_client.cc | 18 ++++++++-- paddle/fluid/operators/detail/grpc_client.h | 36 +++++++++++++------ paddle/fluid/operators/detail/grpc_server.cc | 21 +++++++---- paddle/fluid/operators/detail/grpc_server.h | 2 +- .../fluid/operators/detail/sendrecvop_utils.h | 1 + paddle/fluid/operators/listen_and_serv_op.cc | 4 +-- paddle/fluid/operators/send_op.cc | 6 ++++ python/paddle/fluid/distribute_transpiler.py | 20 +++++++++-- 8 files changed, 83 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 7266f32764..ddeeebec58 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } -bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { +void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); @@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, (void*)s); req_count_++; +} - return true; +void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { + const auto ch = GetChannel(ep); + FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); + s->Prepare(time_out); + + sendrecv::VariableMessage req; + req.set_varname(FETCH_BARRIER_MESSAGE); + auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + req_count_++; } bool RPCClient::Wait() { @@ -154,7 +164,7 @@ bool RPCClient::Proceed() { PADDLE_ENFORCE(tag); // TODO(gongwb): add more retries. - ClientBase* c = static_cast(tag); + BaseProcessor* c = static_cast(tag); if (!c->status_.ok()) { LOG(ERROR) << "proc param error:" << c->var_h_.String() << " grpc error:" << c->status_.error_message(); @@ -174,6 +184,8 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep) { } grpc::ChannelArguments args; + args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000); + args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); args.SetMaxSendMessageSize(std::numeric_limits::max()); args.SetMaxReceiveMessageSize(std::numeric_limits::max()); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 669838810d..f520367dd9 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -52,14 +52,14 @@ struct VarHandle { void ProcGetResponse(const VarHandle& var_h, const sendrecv::VariableMessage& msg); -class ClientBase { +class BaseProcessor { public: - explicit ClientBase(std::shared_ptr ch) { + explicit BaseProcessor(std::shared_ptr ch) { stub_ = sendrecv::SendRecvService::NewStub(ch); context_ = NULL; } - virtual ~ClientBase() {} + virtual ~BaseProcessor() {} virtual void Prepare(const VarHandle& var_info, int64_t time_out) { context_.reset(new grpc::ClientContext()); @@ -91,9 +91,10 @@ class ClientBase { typedef std::function RequestSendCallBack; -class SendProcessor : public ClientBase { +class SendProcessor : public BaseProcessor { public: - explicit SendProcessor(std::shared_ptr ch) : ClientBase(ch) {} + explicit SendProcessor(std::shared_ptr ch) + : BaseProcessor(ch) {} virtual ~SendProcessor() {} @@ -110,9 +111,10 @@ class SendProcessor : public ClientBase { typedef std::function RequestGetCallBack; -class GetProcessor : public ClientBase { +class GetProcessor : public BaseProcessor { public: - explicit GetProcessor(std::shared_ptr ch) : ClientBase(ch) {} + explicit GetProcessor(std::shared_ptr ch) + : BaseProcessor(ch) {} virtual ~GetProcessor() {} @@ -126,10 +128,10 @@ class GetProcessor : public ClientBase { RequestGetCallBack response_call_back_ = ProcGetResponse; }; -class BatchBarrierProcessor : public ClientBase { +class BatchBarrierProcessor : public BaseProcessor { public: explicit BatchBarrierProcessor(std::shared_ptr ch) - : ClientBase(ch) {} + : BaseProcessor(ch) {} virtual ~BatchBarrierProcessor() {} @@ -137,6 +139,17 @@ class BatchBarrierProcessor : public ClientBase { sendrecv::VoidMessage reply_; }; +class FetchBarrierProcessor : public BaseProcessor { + public: + explicit FetchBarrierProcessor(std::shared_ptr ch) + : BaseProcessor(ch) {} + + virtual ~FetchBarrierProcessor() {} + + virtual void Process() {} + sendrecv::VariableMessage reply_; +}; + class RPCClient { public: bool AsyncSendVariable(const std::string& ep, @@ -151,7 +164,10 @@ class RPCClient { const std::string& var_name, int64_t time_out = 600 * 1000); - bool AsyncSendBatchBarrier(const std::string& ep, + void AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out = 600 * 1000); + + void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = 600 * 1000); bool Wait(); diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 2a56751661..8fff430cc4 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -84,7 +84,7 @@ class RequestGet final : public RequestBase { explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, grpc::ServerCompletionQueue* cq, framework::Scope* scope, const platform::DeviceContext* dev_ctx, - SimpleBlockQueue* queue) + SimpleBlockQueue* queue) : RequestBase(service, cq), responder_(&ctx_), scope_(scope), @@ -101,11 +101,16 @@ class RequestGet final : public RequestBase { // proc request. std::string var_name = request_.varname(); auto* var = scope_->FindVar(var_name); - SerializeToMessage(var_name, var, *dev_ctx_, &reply_); + if (var_name != FETCH_BARRIER_MESSAGE) { + SerializeToMessage(var_name, var, *dev_ctx_, &reply_); + } // TODO(gongwb): check var's info. responder_.Finish(reply_, grpc::Status::OK, this); status_ = FINISH; - queue_->Push('c'); + MessageWithName msg_with_name = + // request name reply + std::make_pair(var_name, std::move(reply_)); + queue_->Push(msg_with_name); } protected: @@ -114,12 +119,16 @@ class RequestGet final : public RequestBase { ServerAsyncResponseWriter responder_; framework::Scope* scope_; const platform::DeviceContext* dev_ctx_; - SimpleBlockQueue* queue_; + SimpleBlockQueue* queue_; }; void AsyncGRPCServer::WaitClientGet(int count) { - for (int i = 0; i < count; ++i) { - var_get_queue_.Pop(); + int fetch_barriers = 0; + while (fetch_barriers < count) { + auto msg = var_get_queue_.Pop(); + if (msg.first == FETCH_BARRIER_MESSAGE) { + fetch_barriers++; + } } } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index e9402ff6aa..b6666bcf96 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { const platform::DeviceContext *dev_ctx_; // received variable from RPC, operators fetch variable from this queue. SimpleBlockQueue var_recv_queue_; - SimpleBlockQueue var_get_queue_; + SimpleBlockQueue var_get_queue_; // condition of the sub program std::mutex barrier_mutex_; diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index 5208091e54..4fa6aefd3e 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -32,6 +32,7 @@ namespace detail { #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" +#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" typedef void (*DestroyCallback)(void*); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 8e9923c87c..4253300788 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase { } } if (exit_flag) { - rpc_service_->ShutDown(); rpc_service_->SetCond(1); + rpc_service_->ShutDown(); break; } try { @@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase { } rpc_service_->SetCond(1); // FIXME(typhoonzero): use another condition to sync wait clients get. - rpc_service_->WaitClientGet(ins.size()); + rpc_service_->WaitClientGet(fan_in); sparse_vars.clear(); } // while(true) } diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 8fdd08eae6..443f40e803 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase { rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } PADDLE_ENFORCE(rpc_client->Wait()); + // tell pservers that current trainer have called fetch + for (auto& ep : endpoints) { + VLOG(3) << "send fetch barrier, ep: " << ep; + rpc_client->AsyncSendFetchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); } } }; diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index bb2ce4d45d..3d3a6c116e 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -250,6 +250,8 @@ class DistributeTranspiler: def get_trainer_program(self): # remove optimize ops and add a send op to main_program self.program.global_block().delete_ops(self.optimize_ops) + # FIXME(typhoonzero): serialize once will fix error occurs when clone. + self.program.__str__() return self.program def get_pserver_program(self, endpoint): @@ -309,7 +311,8 @@ class DistributeTranspiler: for _, opt_op in enumerate(opt_op_on_pserver): if ufind.is_connected(op, opt_op): if self._is_opt_op(op): - self._append_pserver_ops(optimize_block, op, endpoint) + self._append_pserver_ops(optimize_block, op, endpoint, + default_main_program()) else: self._append_pserver_non_opt_ops(optimize_block, op) break @@ -520,7 +523,8 @@ class DistributeTranspiler: orig_var_name = varname[:suff_idx] return orig_var_name - def _append_pserver_ops(self, optimize_block, opt_op, endpoint): + def _append_pserver_ops(self, optimize_block, opt_op, endpoint, + origin_program): program = optimize_block.program pserver_block = program.global_block() new_inputs = dict() @@ -576,7 +580,17 @@ class DistributeTranspiler: elif key == "LearningRate": # leraning rate variable has already be created by non-optimize op, # don't create it once again. - new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]] + lr_varname = opt_op.input(key)[0] + if pserver_block.vars.has_key(lr_varname): + new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]] + else: + origin_var = origin_program.global_block().vars[lr_varname] + tmpvar = pserver_block.create_var( + name=origin_var.name, + persistable=origin_var.persistable, + dtype=origin_var.dtype, + shape=origin_var.shape) + new_inputs[key] = tmpvar for key in opt_op.input_names: new_shape = None