From b1e5183627e4aad56400b620342a55434321d544 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 10 May 2018 11:22:46 +0800 Subject: [PATCH 01/17] overlap sendop and backward ops --- paddle/fluid/operators/recv_op.cc | 18 ++- python/paddle/fluid/transpiler/__init__.py | 3 +- .../fluid/transpiler/distribute_transpiler.py | 123 +++++++++++------- .../fluid/transpiler/distributed_splitter.py | 57 -------- .../paddle/fluid/transpiler/ps_dispatcher.py | 78 +++++++++++ 5 files changed, 167 insertions(+), 112 deletions(-) delete mode 100644 python/paddle/fluid/transpiler/distributed_splitter.py create mode 100644 python/paddle/fluid/transpiler/ps_dispatcher.py diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index a4dcf704a6..aeb93c9981 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -36,19 +36,22 @@ class RecvOp : public framework::OperatorBase { const platform::Place& place) const override { auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); + auto client_var_name = Output("RPCClient"); + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), + "Can not find variable '%s' in the scope.", + client_var_name); + auto* client_var = scope.FindVar(client_var_name); + detail::RPCClient* rpc_client = client_var->GetMutable(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); for (size_t i = 0; i < outs.size(); i++) { - VLOG(3) << "getting " << outs[i]; - client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; + rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - PADDLE_ENFORCE(client_.Wait()); + PADDLE_ENFORCE(rpc_client->Wait()); } - - private: - mutable detail::RPCClient client_; }; class RecvOpMaker : public framework::OpProtoAndCheckerMaker { @@ -56,6 +59,9 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker { RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable(); + AddOutput("RPCClient", + "(RPCClient) The RPC client object which is" + "initialized at most once."); AddComment(R"DOC( Recv operator diff --git a/python/paddle/fluid/transpiler/__init__.py b/python/paddle/fluid/transpiler/__init__.py index 6d3c1b947f..f21e4dc033 100644 --- a/python/paddle/fluid/transpiler/__init__.py +++ b/python/paddle/fluid/transpiler/__init__.py @@ -15,8 +15,9 @@ from distribute_transpiler import DistributeTranspiler from inference_transpiler import InferenceTranspiler from memory_optimization_transpiler import memory_optimize, release_memory from distribute_transpiler_simple import SimpleDistributeTranspiler +from ps_dispatcher import HashName, RoundRobin __all__ = [ "DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler", - "memory_optimize", "release_memory" + "memory_optimize", "release_memory", "HashName", "RoundRobin" ] diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 640ac9f085..05ffdefe05 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -17,7 +17,8 @@ from __future__ import print_function import math import distributed_splitter as splitter -from .. import core +from ps_dispatcher import RoundRobin, HashName, PSDispatcher +from .. 
import core, framework from ..framework import Program, default_main_program, Variable, Parameter LOOKUP_TABLE_TYPE = "lookup_table" @@ -144,13 +145,27 @@ def delete_ops(block, ops): block.program.sync_with_cpp() +def find_op_by_input_arg(block, arg_name): + for index, op in enumerate(block.ops): + if arg_name in op.input_arg_names: + return index + return -1 + + +def find_op_by_output_arg(block, arg_name): + for index, op in enumerate(block.ops): + if arg_name in op.output_arg_names: + return index + return -1 + + class DistributeTranspiler: def transpile(self, trainer_id, program=None, pservers="127.0.0.1:6174", trainers=1, - split_method=splitter.round_robin, + split_method=RoundRobin, sync_mode=True): """ Transpile the program to distributed data-parallelism programs. @@ -184,14 +199,14 @@ class DistributeTranspiler: :type pservers: string :param trainers: total number of workers/trainers in the job :type trainers: int - :param split_method: A function to determin how to split variables - to different servers equally. - :type split_method: function + :param split_method: A instance to determin how to dispatch variable + blocks to different servers equally. + :type split_method: A instance based on PSDispatcher class. :type sync_mode: boolean default True :param sync_mode: if sync_mode is set True, it means that dist transpiler will transpile the program into sync_mode pserver and trainer program. """ - assert (callable(split_method)) + assert (split_method.__bases__[0] == PSDispatcher) if program is None: program = default_main_program() self.origin_program = program @@ -204,6 +219,7 @@ class DistributeTranspiler: pserver_endpoints = pservers.split(",") self.pserver_endpoints = pserver_endpoints self.optimize_ops, params_grads = self._get_optimize_pass() + ps_dispatcher = split_method(pserver_endpoints) # process lookup_table_op # 1. check all lookup_table_op is distributed @@ -268,56 +284,67 @@ class DistributeTranspiler: grad_var_mapping = self._append_split_op(program, grad_blocks) param_var_mapping = self._create_vars_from_blocklist(program, param_blocks) - # step3: Add gradients as send op inputs and parameters as send - # op outputs. - send_inputs = [] - send_outputs = [] - for b in grad_blocks: # append by order - varname, block_id, _ = b.split(":") - send_inputs.append(grad_var_mapping[varname][int(block_id)]) - for b in param_blocks: - varname, block_id, _ = b.split(":") - send_outputs.append(param_var_mapping[varname][int(block_id)]) - # let send_op know which endpoint to send which var to, eplist has the same - # order as send_inputs. - eplist = split_method(send_inputs, pserver_endpoints) - # create mapping of endpoint -> split var to create pserver side program - self.param_grad_ep_mapping = dict() - for i, ep in enumerate(eplist): - param = send_outputs[i] - grad = send_inputs[i] - if not self.param_grad_ep_mapping.has_key(ep): - self.param_grad_ep_mapping[ep] = {"params": [], "grads": []} - self.param_grad_ep_mapping[ep]["params"].append(param) - self.param_grad_ep_mapping[ep]["grads"].append(grad) - rpc_client_var = program.global_block().create_var( name=RPC_CLIENT_VAR_NAME, persistable=True, type=core.VarDesc.VarType.RAW) - # create send_op + # step 3: transpile trainer side program, insert recv op and send op. 
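# ---------------------------------------------------------------------------
# Editor's illustrative sketch, not part of this patch: the send_vars ops
# inserted below depend on a PSDispatcher-style object that assigns each
# split gradient block to a pserver endpoint. A minimal round-robin sketch
# (the class name, endpoints and block names here are hypothetical):

class RoundRobinSketch(object):
    def __init__(self, pserver_endpoints):
        self._eps = pserver_endpoints
        self._step = 0

    def dispatch(self, varlist):
        # hand the next endpoint to each split block, wrapping around
        eplist = []
        for _ in varlist:
            eplist.append(self._eps[self._step])
            self._step = (self._step + 1) % len(self._eps)
        return eplist

# RoundRobinSketch(["127.0.0.1:6174", "127.0.0.1:6175"]).dispatch(
#     ["w@GRAD.block0", "w@GRAD.block1", "w@GRAD.block2"])
# returns ['127.0.0.1:6174', '127.0.0.1:6175', '127.0.0.1:6174']
# ---------------------------------------------------------------------------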
+ + # create mapping of endpoint -> split var to create pserver side program + self.param_grad_ep_mapping = dict() + [ + self.param_grad_ep_mapping.update({ + ep: { + "params": [], + "grads": [] + } + }) for ep in self.pserver_endpoints + ] + + # step 3.1: insert send op to send gradient vars to parameter servers + ps_dispatcher.reset() + for varname, send_vars in grad_var_mapping.items(): + index = find_op_by_output_arg(program.global_block(), varname) + eplist = ps_dispatcher.dispatch(send_vars) + program.global_block().insert_op( + index=index, + type="send_vars", + inputs={"X": send_vars}, + outputs={"RPCClient": rpc_client_var}, + attrs={"epmap": eplist}) + + if self.sync_mode: + program.global_block().append_op( + type="send_barrier", + inputs={}, + outputs={"RPCClient": rpc_client_var}, + attrs={"endpoints": pserver_endpoints}) + + # step 3.2: insert recv op to receive parameters from parameter server + ps_dispatcher.reset() + recv_vars = [] + for b in param_blocks: + varname, block_id, _ = b.split(":") + recv_vars.append(param_var_mapping[varname][int(block_id)]) + for b in grad_blocks: + varname, block_id, _ = b.split(":") + send_vars.append(grad_var_mapping[varname][int(block_id)]) + + eplist = ps_dispatcher.dispatch(recv_vars) + + for i, ep in enumerate(eplist): + self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) + self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) + program.global_block().append_op( - type="send", - inputs={"X": send_inputs}, - outputs={"Out": send_outputs, + type="recv", + inputs={}, + outputs={"Out": recv_vars, "RPCClient": rpc_client_var}, - attrs={ - "endpoints": pserver_endpoints, - "epmap": eplist, - "sync_mode": self.sync_mode - }) - # step4: Concat the parameters splits together after recv. - for varname, splited_var in param_var_mapping.iteritems(): - if len(splited_var) <= 1: - continue - orig_param = program.global_block().vars[varname] - program.global_block().append_op( - type="concat", - inputs={"X": splited_var}, - outputs={"Out": [orig_param]}, - attrs={"axis": 0}) + attrs={"epmap": eplist}) + # TODO(Yancey1989): check dist lookup table if self.has_distributed_lookup_table: self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, eplist) diff --git a/python/paddle/fluid/transpiler/distributed_splitter.py b/python/paddle/fluid/transpiler/distributed_splitter.py deleted file mode 100644 index 060c1df8ad..0000000000 --- a/python/paddle/fluid/transpiler/distributed_splitter.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def hash_name(varlist, pserver_endpoints): - """ - hash variable names to several endpoints. 
- - Args: - varlist(list): a list of Variables - - Returns(dict): a map of pserver endpoint -> varname - """ - - def _hash_block(block_str, total): - return hash(block_str) % total - - eplist = [] - for var in varlist: - server_id = _hash_block(var.name(), len(pserver_endpoints)) - server_for_param = pserver_endpoints[server_id] - eplist.append(server_for_param) - return eplist - - -def round_robin(varlist, pserver_endpoints): - """ - Distribute variables to several endpoints. - Args: - varlist(list): a list of variables - pserver_endpoints(list): a list of pserver endpoints - - Returns(list[int]): the endpoint for each variable - """ - assert (len(varlist) >= len(pserver_endpoints)) - - eplist = [] - pserver_idx = 0 - for var in varlist: - server_for_param = pserver_endpoints[pserver_idx] - eplist.append(server_for_param) - - pserver_idx += 1 - if pserver_idx >= len(pserver_endpoints): - pserver_idx = 0 - return eplist diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py new file mode 100644 index 0000000000..dffe66998a --- /dev/null +++ b/python/paddle/fluid/transpiler/ps_dispatcher.py @@ -0,0 +1,78 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class PSDispatcher(object): + """ + DistributedSpliter is the base class for dispatching vars + into different pserver instance. + You need to implement the `dispatch` inferface. + """ + + def __init__(self, pserver_endpoints): + self._eps = pserver_endpoints + self._step = 0 + + @property + def eps(self): + return self._eps + + def reset(self): + self._step = 0 + + def dispatch(self, varlist): + """ + :param varlist: a list of Variables + :return: a map of pserver endpoint -> varname + """ + AssertionError("Interface has not been implemented.") + + +class HashName(PSDispatcher): + """ + Hash variable names to servral endpoints + """ + + def __init__(self, pserver_endpoints): + super(self.__class__, self).__init__(pserver_endpoints) + + def _hash_block(self, block_str, total): + return hash(block_str) % total + + def dispatch(self, varlist): + eplist = [] + for var in varlist: + server_id = self._hash_block(var.name(), len(self._eps)) + server_for_param = self._eps[server_id] + eplist.append(server_for_param) + return eplist + + +class RoundRobin(PSDispatcher): + """ + Distribute variables to serveral endpoints. 
+ """ + + def __init__(self, pserver_endpoints): + super(self.__class__, self).__init__(pserver_endpoints) + + def dispatch(self, varlist): + eplist = [] + for var in varlist: + server_for_param = self._eps[self._step] + eplist.append(server_for_param) + self._step += 1 + if self._step >= len(self._eps): + self._step = 0 + return eplist From 6e5635fd1e7a3e0d308b4e7b98ddd48ea4b1fcc4 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 10 May 2018 19:27:25 +0800 Subject: [PATCH 02/17] update --- .../fluid/transpiler/distribute_transpiler.py | 79 ++++++++++++++----- 1 file changed, 60 insertions(+), 19 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 05ffdefe05..7b8bf17f27 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -279,11 +279,20 @@ class DistributeTranspiler: grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints)) param_blocks = split_dense_variable(param_list, len(pserver_endpoints)) + assert (len(grad_blocks) == len(param_blocks)) # step2: Create new vars for the parameters and gradients blocks and # add ops to do the split. - grad_var_mapping = self._append_split_op(program, grad_blocks) param_var_mapping = self._create_vars_from_blocklist(program, param_blocks) + grad_var_mapping = self._create_vars_from_blocklist( + program, grad_blocks, add_trainer_suffix=self.trainer_num > 1) + grad_param_mapping = dict() + for g, p in zip(grad_blocks, param_blocks): + g_name, g_bid, _ = g.split(":") + p_name, p_bid, _ = p.split(":") + grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \ + param_var_mapping[p_name][int(p_bid)] + rpc_client_var = program.global_block().create_var( name=RPC_CLIENT_VAR_NAME, persistable=True, @@ -304,15 +313,21 @@ class DistributeTranspiler: # step 3.1: insert send op to send gradient vars to parameter servers ps_dispatcher.reset() - for varname, send_vars in grad_var_mapping.items(): + send_vars = [] + for varname, splited_vars in grad_var_mapping.items(): index = find_op_by_output_arg(program.global_block(), varname) - eplist = ps_dispatcher.dispatch(send_vars) + eplist = ps_dispatcher.dispatch(splited_vars) + if len(splited_vars) > 1: + self._insert_split_op(program, varname, splited_vars) + index += 1 program.global_block().insert_op( - index=index, + index=index + 1, type="send_vars", - inputs={"X": send_vars}, + inputs={"X": splited_vars}, outputs={"RPCClient": rpc_client_var}, attrs={"epmap": eplist}) + for _, var in enumerate(splited_vars): + send_vars.append(var) if self.sync_mode: program.global_block().append_op( @@ -322,21 +337,12 @@ class DistributeTranspiler: attrs={"endpoints": pserver_endpoints}) # step 3.2: insert recv op to receive parameters from parameter server - ps_dispatcher.reset() recv_vars = [] - for b in param_blocks: - varname, block_id, _ = b.split(":") - recv_vars.append(param_var_mapping[varname][int(block_id)]) - for b in grad_blocks: - varname, block_id, _ = b.split(":") - send_vars.append(grad_var_mapping[varname][int(block_id)]) - + for _, var in enumerate(send_vars): + recv_vars.append(grad_param_mapping[var]) + ps_dispatcher.reset() eplist = ps_dispatcher.dispatch(recv_vars) - for i, ep in enumerate(eplist): - self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) - self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) - program.global_block().append_op( type="recv", inputs={}, @@ -344,6 +350,10 @@ class 
DistributeTranspiler: "RPCClient": rpc_client_var}, attrs={"epmap": eplist}) + for i, ep in enumerate(eplist): + self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) + self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) + # TODO(Yancey1989): check dist lookup table if self.has_distributed_lookup_table: self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, @@ -848,6 +858,34 @@ class DistributeTranspiler: lod_level=var.lod_level, persistable=persistable) + def _insert_split_op(self, program, orig_varname, splited_vars): + orig_var = program.global_block().vars[orig_varname] + index = find_op_by_output_arg(program.global_block(), orig_varname) + if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: + height_sections = [] + for v in splited_vars: + height_sections.append(v.shape[0]) + program.global_block().insert_op( + index=index + 1, + type="split_selected_rows", + inputs={"X": orig_var}, + outputs={"Out": splited_vars}, + attrs={"height_sections": height_sections}) + elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: + sections = [] + for v in splited_vars: + sections.append(v.shape[0]) + program.global_block().insert_op( + index=index + 1, + type="split_byref", + inputs={"X": orig_var}, + outputs={"Out": splited_vars}, + attrs={"sections": sections} # assume split evenly + ) + else: + AssertionError("Variable type should be in set " + "[LOD_TENSOR, SELECTED_ROWS]") + def _append_split_op(self, program, gradblocks): # Split variables that need to be split and append respective ops add_suffix = False @@ -860,11 +898,13 @@ class DistributeTranspiler: if len(splited_vars) <= 1: continue orig_var = program.global_block().vars[varname] + index = find_op_by_output_arg(program.global_block(), orig_var.name) if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: height_sections = [] for v in splited_vars: height_sections.append(v.shape[0]) - program.global_block().append_op( + program.global_block().insert_op( + index=index + 1, type="split_selected_rows", inputs={"X": orig_var}, outputs={"Out": splited_vars}, @@ -873,7 +913,8 @@ class DistributeTranspiler: sections = [] for v in splited_vars: sections.append(v.shape[0]) - program.global_block().append_op( + program.global_block().insert_op( + index=index + 1, type="split_byref", inputs={"X": orig_var}, outputs={"Out": splited_vars}, From 315e44acee06ca1933f6f77ae8386afb5544cdfb Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 15 May 2018 16:23:03 +0800 Subject: [PATCH 03/17] add fetch_barrier_op --- paddle/fluid/framework/executor.cc | 2 +- paddle/fluid/operators/CMakeLists.txt | 4 +- paddle/fluid/operators/fetch_barrier_op.cc | 101 ++++++++++++++++++ .../fluid/transpiler/distribute_transpiler.py | 67 ++++-------- 4 files changed, 126 insertions(+), 48 deletions(-) create mode 100644 paddle/fluid/operators/fetch_barrier_op.cc diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index ce91d7a826..d411ae3466 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -353,7 +353,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, scope->DeleteScope(local_scope); } else { // Delete the local scopes created in operators. 
- scope->DropKids(); + // scope->DropKids(); } if (FLAGS_benchmark) { VLOG(2) << "-------------------------------------------------------"; diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c14a2b7786..39c20bb211 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -199,11 +199,13 @@ if(WITH_DISTRIBUTE) op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS}) + op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor) else() - set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op) + set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op) endif() op_library(cross_entropy_op DEPS cross_entropy) diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc new file mode 100644 index 0000000000..3dfdd135ee --- /dev/null +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include // NOLINT +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#include "paddle/fluid/operators/detail/grpc_client.h" + +namespace paddle { +namespace operators { + +class FetchBarrierOp : public framework::OperatorBase { + public: + FetchBarrierOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + std::vector eps = Attr>("endpoints"); + + auto client_var_name = Output("RPCClient"); + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), + "Can not find variable '%s' in the scope.", + client_var_name); + auto* client_var = scope.FindVar(client_var_name); + detail::RPCClient* rpc_client = client_var->GetMutable(); + + PADDLE_ENFORCE(rpc_client->Wait()); + + for (auto& ep : eps) { + VLOG(3) << "fetch barrier, ep: " << ep; + rpc_client->AsyncSendFetchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); + } +}; + +class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddOutput("RPCClient", + "(RPCClient) The RPC client object which is" + "initialized at most once."); + AddComment(R"DOC( +SendBarrier operator + +This operator will send a send barrier signal to list_and_serv op, so that +the Parameter Server would knew all variables have been sent. +)DOC"); + + AddAttr>("endpoints", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints to send variables to.") + .SetDefault({"127.0.0.1:6164"}); + } +}; + +class FetchBarrierOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto out_var_name = op_desc.Output("RPCClient").front(); + auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class FetchBarrierOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp, + paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker, + ops::FetchBarrierOpVarTypeInference, + ops::FetchBarrierOpShapeInference); diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 8668edab94..5e90f3f64a 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -315,12 +315,22 @@ class DistributeTranspiler: # step 3.1: insert send op to send gradient vars to parameter servers ps_dispatcher.reset() send_vars = [] - for varname, splited_vars in grad_var_mapping.items(): - index = find_op_by_output_arg(program.global_block(), varname) + for orig_varname, splited_vars in grad_var_mapping.items(): eplist = ps_dispatcher.dispatch(splited_vars) - if len(splited_vars) > 1: - self._insert_split_op(program, varname, splited_vars) + if len(splited_vars) == 1: + orig_varname = splited_vars[0].name + index = find_op_by_output_arg(program.global_block(), + orig_varname) + elif 
len(splited_vars) > 1: + orig_var = program.global_block().vars[orig_varname] + index = find_op_by_output_arg(program.global_block(), + orig_varname) + self._insert_split_op(program, orig_var, index, splited_vars) index += 1 + else: + AssertionError("Can not insert the send op by original " + "variable name :", orig_varname) + program.global_block().insert_op( index=index + 1, type="send_vars", @@ -351,6 +361,12 @@ class DistributeTranspiler: "RPCClient": rpc_client_var}, attrs={"epmap": eplist}) + program.global_block().append_op( + type="fetch_barrier", + inputs={}, + outputs={"RPCClient": rpc_client_var}, + attrs={"endpoints": pserver_endpoints}) + for i, ep in enumerate(eplist): self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) @@ -859,9 +875,7 @@ class DistributeTranspiler: lod_level=var.lod_level, persistable=persistable) - def _insert_split_op(self, program, orig_varname, splited_vars): - orig_var = program.global_block().vars[orig_varname] - index = find_op_by_output_arg(program.global_block(), orig_varname) + def _insert_split_op(self, program, orig_var, index, splited_vars): if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: height_sections = [] for v in splited_vars: @@ -887,45 +901,6 @@ class DistributeTranspiler: AssertionError("Variable type should be in set " "[LOD_TENSOR, SELECTED_ROWS]") - def _append_split_op(self, program, gradblocks): - # Split variables that need to be split and append respective ops - add_suffix = False - if self.trainer_num > 1: - add_suffix = True - var_mapping = self._create_vars_from_blocklist( - program, gradblocks, add_trainer_suffix=add_suffix) - for varname, splited_vars in var_mapping.iteritems(): - # variable that don't need to split have empty splited_vars - if len(splited_vars) <= 1: - continue - orig_var = program.global_block().vars[varname] - index = find_op_by_output_arg(program.global_block(), orig_var.name) - if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: - height_sections = [] - for v in splited_vars: - height_sections.append(v.shape[0]) - program.global_block().insert_op( - index=index + 1, - type="split_selected_rows", - inputs={"X": orig_var}, - outputs={"Out": splited_vars}, - attrs={"height_sections": height_sections}) - elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: - sections = [] - for v in splited_vars: - sections.append(v.shape[0]) - program.global_block().insert_op( - index=index + 1, - type="split_byref", - inputs={"X": orig_var}, - outputs={"Out": splited_vars}, - attrs={"sections": sections} # assume split evenly - ) - else: - AssertionError("Variable type should be in set " - "[LOD_TENSOR, SELECTED_ROWS]") - return var_mapping - def _get_optimizer_input_shape(self, op_type, varkey, orig_shape, param_shape): """ From eb2e68ee6bcf3008291ea64b52ab1dff6f2bc52c Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 15 May 2018 16:40:18 +0800 Subject: [PATCH 04/17] revert executor run --- paddle/fluid/framework/executor.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index d411ae3466..ce91d7a826 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -353,7 +353,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, scope->DeleteScope(local_scope); } else { // Delete the local scopes created in operators. 
- // scope->DropKids(); + scope->DropKids(); } if (FLAGS_benchmark) { VLOG(2) << "-------------------------------------------------------"; From 62af10d440473867c42cd9a8e2e5a3b8d854d500 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 21 May 2018 18:52:32 +0800 Subject: [PATCH 05/17] support multiple devices --- paddle/fluid/framework/details/CMakeLists.txt | 3 +- .../details/multi_devices_graph_builder.cc | 60 +++++++++++++++---- .../details/multi_devices_graph_builder.h | 5 ++ .../fluid/framework/details/rpc_op_handle.cc | 50 ++++++++++++++++ .../fluid/framework/details/rpc_op_handle.h | 52 ++++++++++++++++ paddle/fluid/framework/variable.h | 3 + paddle/fluid/operators/detail/grpc_client.cc | 26 ++++---- paddle/fluid/operators/detail/grpc_client.h | 5 +- paddle/fluid/operators/fetch_barrier_op.cc | 6 ++ paddle/fluid/operators/recv_op.cc | 10 +++- paddle/fluid/operators/send_barrier_op.cc | 5 ++ paddle/fluid/operators/send_recv_util.h | 3 + paddle/fluid/operators/send_vars_op.cc | 6 ++ 13 files changed, 208 insertions(+), 26 deletions(-) create mode 100644 paddle/fluid/framework/details/rpc_op_handle.cc create mode 100644 paddle/fluid/framework/details/rpc_op_handle.h diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 9de44beafb..2c838f4361 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -4,6 +4,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry) +cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) @@ -26,7 +27,7 @@ endif() cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) + scale_loss_grad_op_handle send_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 45bad58145..50998fb8e0 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include #include #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" +#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/scope.h" @@ -77,7 +79,6 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, CreateOpOutput(result, op_handle, each_var_name, p, place_id); } } - bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const { if (send_op == nullptr) { @@ -98,7 +99,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, return false; }; - if (op.Type() == "split") { + if (op.Type() == "split" || op.Type() == "split_byref") { return checker(op.OutputArgumentNames(), send_op->InputArgumentNames()); } else if (op.Type() == "concat") { return checker(op.InputArgumentNames(), send_op->OutputArgumentNames()); @@ -106,6 +107,15 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, return false; } +bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const { + for (auto &name : op.OutputNames()) { + if (name == "RPCClient") { + return true; + } + } + return false; +} + std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { std::unordered_map var_types; @@ -133,10 +143,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( bool is_forwarding = true; for (auto *op : program.Block(0).AllOps()) { - if (op->Type() == "send") { - // append send op if program is distributed trainer main program. + if (IsRPCOp(*op)) { + // append rpc op if program is distributed trainer main program. 
// always use the first device - CreateSendOp(&result, *op); + CreateRPCOp(&result, *op); } else if (IsDistTrainOp(*op, send_op)) { CreateComputationalOps(&result, *op, 1); } else if (IsScaleLossOp(*op)) { @@ -203,9 +213,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( AddOutputToLeafOps(&result); if (VLOG_IS_ON(10)) { - std::ostringstream sout; - PrintGraphviz(*graph, sout); - VLOG(10) << sout.str(); + std::string filename = "/tmp/graph"; + std::ofstream fout(filename); + PrintGraphviz(*graph, fout); } return std::unique_ptr(graph); @@ -386,12 +396,40 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, return var; } -void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, - const OpDesc &op) const { +void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, + std::string op_name) const { + for (auto &prev_op : result->ops_) { + if (prev_op->Name() == op_name) { + auto *dep_var = new DummyVarHandle(); + prev_op->AddOutput(dep_var); + result->dep_vars_.emplace(dep_var); + result->ops_.back().get()->AddInput(dep_var); + } + } +} + +void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, + const OpDesc &op) const { auto &p = places_[0]; auto *s = local_scopes_[0]; + VLOG(3) << "create rpc op: " << op.Type(); + result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type())); + if (op.Type() == "send_barrier") { + ConnectOp(result, "send_vars"); + } else if (op.Type() == "recv") { + ConnectOp(result, "send_barrier"); + } else if (op.Type() == "fetch_barrier") { + ConnectOp(result, "recv"); + } else if (op.Type() == "send" || op.Type() == "send_vars") { + // do nothing + } else { + PADDLE_THROW( + "rpc op should be in [send," + "send_vars, send_barrier. recv, fetch_barrier]"); + } + // FIXME(wuyi): send op always copy from GPU 0 - result->ops_.emplace_back(new SendOpHandle(op, s, p)); + // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type())); // Create inputs for output on original place and no ssa output // is created for send op. CreateOpHandleIOs(result, op, 0); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 4f70852188..45713b0c4f 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -65,12 +65,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; void CreateSendOp(SSAGraph *result, const OpDesc &op) const; + void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; /** * Is this operator as the end-point operator before/after send operator. */ bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const; + bool IsRPCOp(const OpDesc &op) const; + + void ConnectOp(SSAGraph *result, std::string op_name) const; + void CreateComputationalOps(SSAGraph *result, const OpDesc &op, size_t num_places) const; diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc new file mode 100644 index 0000000000..03f53421b1 --- /dev/null +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/rpc_op_handle.h" + +namespace paddle { +namespace framework { +namespace details { + +RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, + const Scope *local_scope, const platform::Place &place, + const std::string &name) + : op_(framework::OpRegistry::CreateOp(op_desc)), + local_scope_(local_scope), + place_(place), + name_(name) {} + +void RPCOpHandle::RunImpl() { + // TODO(wuyi): need further analysis whether wait VarDummyHandle. + // Wait input done + for (auto *in : inputs_) { + auto &p = static_cast(in)->place_; + if (in->DebugString() == "dummy") { // HACK + continue; + } + if (in->generated_op_) { + in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]); + } + } + auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); + // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead + // lock. + op_->Run(*tmp_scope, place_); +} + +std::string RPCOpHandle::Name() const { return name_; } +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/rpc_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h new file mode 100644 index 0000000000..d28b772172 --- /dev/null +++ b/paddle/fluid/framework/details/rpc_op_handle.h @@ -0,0 +1,52 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace framework { +namespace details { + +struct RPCOpHandle : public OpHandleBase { + RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, + const platform::Place& place, const std::string& name); + + std::string Name() const override; + + // Delay and buffer nccl_all_reduce together can significantly increase + // performance. Disable this feature by returning false. 
+ bool IsMultiDeviceTransfer() override { return false; }; + + protected: + void RunImpl() override; + + private: + std::unique_ptr op_; + const Scope* local_scope_; + const platform::Place& place_; + const std::string name_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h index 067e0c2b83..387e06bca6 100644 --- a/paddle/fluid/framework/variable.h +++ b/paddle/fluid/framework/variable.h @@ -14,6 +14,7 @@ #pragma once #include +#include // NOLINT #include #include #include @@ -38,6 +39,7 @@ class Variable { template T* GetMutable() { + std::unique_lock lock(mutex_); if (!IsType()) { holder_.reset(new PlaceholderImpl(new T())); } @@ -90,6 +92,7 @@ class Variable { // by its address but not the unreadable name. friend class Scope; const std::string* name_; + std::mutex mutex_; }; } // namespace framework diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ae60ab1532..ca0518d4dc 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -33,7 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, const std::string ep_val = ep; const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val); + const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val); framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { @@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, const std::string ep_val = ep; const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val); + const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val); framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { @@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, const std::string in_var_name_val = in_var_name; const std::string out_var_name_val = out_var_name; const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val); + const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { @@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, } void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { - const auto ch = GetChannel(ep); + const auto ch = GetChannel(ep, ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); s->Prepare(time_out); @@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { } void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { - const auto ch = GetChannel(ep); + const auto ch = GetChannel(ep, ep); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); s->Prepare(time_out); @@ -243,12 +243,19 @@ bool RPCClient::Proceed() { delete c; return true; } - -std::shared_ptr RPCClient::GetChannel(const std::string& ep) { - auto it = channels_.find(ep); +std::shared_ptr RPCClient::GetChannel(const std::string& ep, + const std::string& key) { + VLOG(3) << "this addr: " << this; + std::unique_lock lock(mutex_); + auto it = channels_.find(key); if (it != channels_.end()) { + VLOG(3) << "find ep: " << ep; return it->second; } + VLOG(3) << "can not find ep: " << ep; + for (auto it = channels_.begin(); it != channels_.end(); 
++it) { + VLOG(3) << "ep: " << it->first; + } grpc::ChannelArguments args; args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); @@ -257,8 +264,7 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep) { auto ch = grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); - - channels_[ep] = ch; + channels_[key] = ch; return ch; } diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index dabce7414d..4e1d608549 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -21,6 +21,7 @@ limitations under the License. */ #include #include #include +#include // NOLINT #include #include @@ -190,12 +191,14 @@ class RPCClient { private: bool Proceed(); - std::shared_ptr GetChannel(const std::string& ep); + std::shared_ptr GetChannel(const std::string& ep, + const std::string& key); private: grpc::CompletionQueue cq_; std::map> channels_; int64_t req_count_ = 0; + std::mutex mutex_; }; } // namespace detail diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index 3dfdd135ee..5d2e558699 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -37,6 +38,11 @@ class FetchBarrierOp : public framework::OperatorBase { const platform::Place& place) const override { std::vector eps = Attr>("endpoints"); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + // For profiling + platform::RecordEvent record_event(Type(), &ctx); + auto client_var_name = Output("RPCClient"); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), "Can not find variable '%s' in the scope.", diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 3b5459f3e3..7ca3c20c7d 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -21,6 +21,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -37,15 +38,18 @@ class RecvOp : public framework::OperatorBase { auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); auto client_var_name = Output("RPCClient"); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + // For profiling + platform::RecordEvent record_event(Type(), &ctx); + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), "Can not find variable '%s' in the scope.", client_var_name); auto* client_var = scope.FindVar(client_var_name); detail::RPCClient* rpc_client = client_var->GetMutable(); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 1ce0907f3a..05e2623630 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -37,6 +38,10 @@ class SendBarrierOp : public framework::OperatorBase { const platform::Place& place) const override { std::vector eps = Attr>("endpoints"); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + // For profiling + platform::RecordEvent record_event(Type(), &ctx); auto client_var_name = Output("RPCClient"); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), "Can not find variable '%s' in the scope.", diff --git a/paddle/fluid/operators/send_recv_util.h b/paddle/fluid/operators/send_recv_util.h index 113513eb6b..deab005149 100644 --- a/paddle/fluid/operators/send_recv_util.h +++ b/paddle/fluid/operators/send_recv_util.h @@ -20,6 +20,9 @@ namespace operators { inline bool NeedSend(const framework::Scope& scope, const std::string& varname) { + // dummy variable is only used in parallel executor to represent + // some dependency relationship, we don't need to send/recv it. + if (varname == "dummy") return false; auto* var = scope.FindVar(varname); PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", varname); diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc index f11e84c176..3caceba4e9 100644 --- a/paddle/fluid/operators/send_vars_op.cc +++ b/paddle/fluid/operators/send_vars_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/send_recv_util.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -41,12 +42,17 @@ class SendVarsOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); + // For profiling + platform::RecordEvent record_event(Type(), &ctx); + auto client_var_name = Output("RPCClient"); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), "Can not find variable '%s' in the scope.", client_var_name); auto* client_var = scope.FindVar(client_var_name); + VLOG(3) << "client var addr: " << client_var; detail::RPCClient* rpc_client = client_var->GetMutable(); + VLOG(3) << "rpc_client addr: " << rpc_client; for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { From e9abc66910a9ee613c60c6ccfcba86f3eed8d429 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 22 May 2018 16:48:40 +0800 Subject: [PATCH 06/17] fix pe --- .../details/computation_op_handle.cc | 2 + .../details/multi_devices_graph_builder.cc | 84 +++++++++++++------ .../details/multi_devices_graph_builder.h | 14 +++- paddle/fluid/operators/detail/grpc_client.cc | 6 -- .../fluid/transpiler/distribute_transpiler.py | 10 +++ 5 files changed, 82 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index df05bb0633..f6e1208a01 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -29,7 +29,9 @@ void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); this->RunAndRecordEvent([this] { + VLOG(3) << "begin run op type is " << op_->Type(); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get(), place_); + VLOG(3) << "end run op type is " << op_->Type(); }); } diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 50998fb8e0..fb5b8608b3 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" -#include #include #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" @@ -79,9 +78,39 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, CreateOpOutput(result, op_handle, each_var_name, p, place_id); } } -bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, - OpDesc *send_op) const { - if (send_op == nullptr) { + +std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( + const ProgramDesc &program) const { + std::vector send_vars; + for (auto *op : program.Block(0).AllOps()) { + if (op->Type() == "send_vars" || op->Type() == "send") { + auto op_vars = op->InputArgumentNames(); + send_vars.reserve(send_vars.size() + + std::distance(op_vars.begin(), op_vars.end())); + send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end()); + } + } + return send_vars; +} + +std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( + const ProgramDesc &program) const { + std::vector recv_vars; + for (auto *op : program.Block(0).AllOps()) { + if (op->Type() == "recv" || op->Type() == "send") { + auto op_vars = op->OutputArgumentNames(); + recv_vars.reserve(recv_vars.size() + + std::distance(op_vars.begin(), op_vars.end())); + recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end()); + } + } + return recv_vars; +} + +bool MultiDevSSAGraphBuilder::IsDistTrainOp( + const OpDesc &op, const std::vector &send_vars, + const std::vector &recv_vars) const { + if (send_vars.size() == 0 || recv_vars.size() == 0) { return false; } @@ -89,21 +118,23 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, * Check any of opvars contains `.block` and in sendvars */ auto checker = [](const std::vector &opvars, - const std::vector &sendvars) -> bool { + const std::vector &rpc_vars) -> bool { for (auto &var : opvars) { if (var.find(".block") != std::string::npos && - std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) { + std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { return true; } } return false; }; - if (op.Type() == "split" || op.Type() == "split_byref") { - return checker(op.OutputArgumentNames(), send_op->InputArgumentNames()); + if (op.Type() == "split" || op.Type() == "split_byref" || + op.Type() == "split_selected_rows") { + return checker(op.OutputArgumentNames(), send_vars); } else if (op.Type() == "concat") { - return checker(op.InputArgumentNames(), send_op->OutputArgumentNames()); + return checker(op.InputArgumentNames(), recv_vars); } + return false; } @@ -132,8 +163,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( std::unordered_map>>>( places_.size()); - // Find "send" op first for split is in front of send. - OpDesc *send_op = GetSendOpDesc(program); + // find send/recv vars so that we can place the distributed training + // realted op in the place 0 + auto send_vars = FindDistTrainSendVars(program); + auto recv_vars = FindDistTrainRecvVars(program); size_t cur_device_id = 0; std::vector> var_name_on_devices; @@ -147,8 +180,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( // append rpc op if program is distributed trainer main program. 
// always use the first device CreateRPCOp(&result, *op); - } else if (IsDistTrainOp(*op, send_op)) { - CreateComputationalOps(&result, *op, 1); + } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { + // CreateComputationalOps(&result, *op, 1); + CreateComputationalOp(&result, *op, 0); } else if (IsScaleLossOp(*op)) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != @@ -213,9 +247,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( AddOutputToLeafOps(&result); if (VLOG_IS_ON(10)) { - std::string filename = "/tmp/graph"; - std::ofstream fout(filename); - PrintGraphviz(*graph, fout); + std::ostringstream sout; + PrintGraphviz(*graph, sout); + VLOG(10) << sout.str(); } return std::unique_ptr(graph); @@ -274,6 +308,7 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc( } return nullptr; } + void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( SSAGraph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA @@ -396,14 +431,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, return var; } -void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, - std::string op_name) const { +void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, + const std::string &prev_op_name) const { for (auto &prev_op : result->ops_) { - if (prev_op->Name() == op_name) { + if (prev_op->Name() == prev_op_name) { auto *dep_var = new DummyVarHandle(); prev_op->AddOutput(dep_var); result->dep_vars_.emplace(dep_var); - result->ops_.back().get()->AddInput(dep_var); + op->AddInput(dep_var); } } } @@ -412,14 +447,14 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op) const { auto &p = places_[0]; auto *s = local_scopes_[0]; - VLOG(3) << "create rpc op: " << op.Type(); result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type())); + if (op.Type() == "send_barrier") { - ConnectOp(result, "send_vars"); + ConnectOp(result, result->ops_.back().get(), "send_vars"); } else if (op.Type() == "recv") { - ConnectOp(result, "send_barrier"); + ConnectOp(result, result->ops_.back().get(), "send_barrier"); } else if (op.Type() == "fetch_barrier") { - ConnectOp(result, "recv"); + ConnectOp(result, result->ops_.back().get(), "recv"); } else if (op.Type() == "send" || op.Type() == "send_vars") { // do nothing } else { @@ -429,7 +464,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, } // FIXME(wuyi): send op always copy from GPU 0 - // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type())); // Create inputs for output on original place and no ssa output // is created for send op. CreateOpHandleIOs(result, op, 0); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 45713b0c4f..1d0021c954 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -64,17 +64,25 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; - void CreateSendOp(SSAGraph *result, const OpDesc &op) const; void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; /** * Is this operator as the end-point operator before/after send operator. 
*/ - bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const; + bool IsDistTrainOp(const OpDesc &op, + const std::vector &send_vars, + const std::vector &recv_vars) const; + + std::vector FindDistTrainSendVars( + const ProgramDesc &program) const; + + std::vector FindDistTrainRecvVars( + const ProgramDesc &program) const; bool IsRPCOp(const OpDesc &op) const; - void ConnectOp(SSAGraph *result, std::string op_name) const; + void ConnectOp(SSAGraph *result, OpHandleBase *op, + const std::string &prev_op_name) const; void CreateComputationalOps(SSAGraph *result, const OpDesc &op, size_t num_places) const; diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ca0518d4dc..a758205938 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -245,17 +245,11 @@ bool RPCClient::Proceed() { } std::shared_ptr RPCClient::GetChannel(const std::string& ep, const std::string& key) { - VLOG(3) << "this addr: " << this; std::unique_lock lock(mutex_); auto it = channels_.find(key); if (it != channels_.end()) { - VLOG(3) << "find ep: " << ep; return it->second; } - VLOG(3) << "can not find ep: " << ep; - for (auto it = channels_.begin(); it != channels_.end(); ++it) { - VLOG(3) << "ep: " << it->first; - } grpc::ChannelArguments args; args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 806cc2fcc1..cf7775e8ed 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -373,6 +373,16 @@ class DistributeTranspiler: for i, ep in enumerate(eplist): self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) + # step4: Concat the parameters splits together after recv. 
+ for varname, splited_var in param_var_mapping.iteritems(): + if len(splited_var) <= 1: + continue + orig_param = program.global_block().vars[varname] + program.global_block().append_op( + type="concat", + inputs={"X": splited_var}, + outputs={"Out": [orig_param]}, + attrs={"axis": 0}) # TODO(Yancey1989): check dist lookup table if self.has_distributed_lookup_table: From 147d54ba621966c5b8b1e8740f1f3e2cad00b98e Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 22 May 2018 17:33:58 +0800 Subject: [PATCH 07/17] update --- paddle/fluid/framework/details/computation_op_handle.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index f6e1208a01..df05bb0633 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -29,9 +29,7 @@ void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); this->RunAndRecordEvent([this] { - VLOG(3) << "begin run op type is " << op_->Type(); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get(), place_); - VLOG(3) << "end run op type is " << op_->Type(); }); } From 6debbcd9f9d88bb2308aa9de7c02cc1b9a09d08f Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Wed, 23 May 2018 11:44:09 +0800 Subject: [PATCH 08/17] connect fetch barrier and concat op --- .../details/multi_devices_graph_builder.cc | 17 +++++--- .../details/multi_devices_graph_builder.h | 1 + paddle/fluid/operators/recv_op.cc | 9 ++++- paddle/fluid/operators/send_vars_op.cc | 2 - .../fluid/transpiler/distribute_transpiler.py | 39 ++++++++++++++----- 5 files changed, 50 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index fb5b8608b3..52e691a617 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include #include #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" @@ -181,8 +182,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( // always use the first device CreateRPCOp(&result, *op); } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { - // CreateComputationalOps(&result, *op, 1); - CreateComputationalOp(&result, *op, 0); + CreateDistTrainOp(&result, *op); } else if (IsScaleLossOp(*op)) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != @@ -247,9 +247,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( AddOutputToLeafOps(&result); if (VLOG_IS_ON(10)) { - std::ostringstream sout; - PrintGraphviz(*graph, sout); - VLOG(10) << sout.str(); + std::ofstream fout("/tmp/graph.dot"); + PrintGraphviz(*graph, fout); } return std::unique_ptr(graph); @@ -443,6 +442,14 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, } } +void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, + const OpDesc &op) const { + CreateComputationalOp(result, op, 0); + if (op.Type() == "concat") { + ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); + } +} + void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op) const { auto &p = places_[0]; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 1d0021c954..cef21e4650 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -65,6 +65,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; + void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const; /** * Is this operator as the end-point operator before/after send operator. diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 7ca3c20c7d..1255ed4c49 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -38,6 +38,7 @@ class RecvOp : public framework::OperatorBase { auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); auto client_var_name = Output("RPCClient"); + int sync_recv = Attr("sync_recv"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); @@ -54,7 +55,9 @@ class RecvOp : public framework::OperatorBase { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - PADDLE_ENFORCE(rpc_client->Wait()); + if (sync_recv) { + PADDLE_ENFORCE(rpc_client->Wait()); + } } }; @@ -75,6 +78,10 @@ This operator can get variables from server side. 
"Server endpoints in the order of input " "variables for mapping") .SetDefault({}); + AddAttr("sync_recv", + "(int, default 0)" + "sync recv or async recv.") + .SetDefault(0); } }; diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc index 3caceba4e9..8d5b5f4292 100644 --- a/paddle/fluid/operators/send_vars_op.cc +++ b/paddle/fluid/operators/send_vars_op.cc @@ -50,9 +50,7 @@ class SendVarsOp : public framework::OperatorBase { "Can not find variable '%s' in the scope.", client_var_name); auto* client_var = scope.FindVar(client_var_name); - VLOG(3) << "client var addr: " << client_var; detail::RPCClient* rpc_client = client_var->GetMutable(); - VLOG(3) << "rpc_client addr: " << rpc_client; for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index cf7775e8ed..e6a4e64e7f 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -357,12 +357,35 @@ class DistributeTranspiler: ps_dispatcher.reset() eplist = ps_dispatcher.dispatch(recv_vars) - program.global_block().append_op( - type="recv", - inputs={}, - outputs={"Out": recv_vars, - "RPCClient": rpc_client_var}, - attrs={"epmap": eplist}) + #program.global_block().append_op( + # type="recv", + # inputs={}, + # outputs={"Out": recv_vars, + # "RPCClient": rpc_client_var}, + # attrs={"epmap": eplist}) + + #program.global_block().append_op( + # type="fetch_barrier", + # inputs={}, + # outputs={"RPCClient": rpc_client_var}, + # attrs={"endpoints": pserver_endpoints}) + + for i, ep in enumerate(eplist): + self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) + self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) + # step4: Concat the parameters splits together after recv. + for varname, splited_var in param_var_mapping.iteritems(): + eps = [] + for var in splited_var: + index = [v.name for v in recv_vars].index(var.name) + eps.append(eplist[index]) + + program.global_block().append_op( + type="recv", + inputs={}, + outputs={"Out": splited_var, + "RPCClient": rpc_client_var}, + attrs={"epmap": eps}) program.global_block().append_op( type="fetch_barrier", @@ -370,10 +393,6 @@ class DistributeTranspiler: outputs={"RPCClient": rpc_client_var}, attrs={"endpoints": pserver_endpoints}) - for i, ep in enumerate(eplist): - self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) - self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) - # step4: Concat the parameters splits together after recv. 
for varname, splited_var in param_var_mapping.iteritems(): if len(splited_var) <= 1: continue From 540b45350d252e008b7c786e214aa9b096144f9a Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Wed, 23 May 2018 15:06:37 +0800 Subject: [PATCH 09/17] use req_count as atomic type --- .../framework/details/multi_devices_graph_builder.cc | 10 ---------- .../framework/details/multi_devices_graph_builder.h | 6 ------ paddle/fluid/operators/detail/grpc_client.cc | 1 - paddle/fluid/operators/detail/grpc_client.h | 2 +- .../paddle/fluid/transpiler/distribute_transpiler.py | 1 - 5 files changed, 1 insertion(+), 19 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 52e691a617..131989b2b0 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -298,16 +298,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, CreateOpHandleIOs(result, op, dev_id); } -OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc( - const ProgramDesc &program) const { - for (auto *op : program.Block(0).AllOps()) { - if (op->Type() == "send") { - return op; - } - } - return nullptr; -} - void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( SSAGraph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index cef21e4650..be17c2a92e 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -107,12 +107,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, size_t src_dev_id) const; - /** - * Get send op in the global block of program. - * nullptr if not found. 
- */ - OpDesc *GetSendOpDesc(const ProgramDesc &program) const; - bool IsSparseGradient( const std::unordered_map &var_types, const std::string &og) const; diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index a758205938..d8d0075934 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -59,7 +59,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, call->StartCall(); call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); }); - req_count_++; return true; diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 4e1d608549..6f8b67be3e 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -197,7 +197,7 @@ class RPCClient { private: grpc::CompletionQueue cq_; std::map> channels_; - int64_t req_count_ = 0; + std::atomic req_count_{0}; std::mutex mutex_; }; diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index e6a4e64e7f..848cb0bd6c 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -403,7 +403,6 @@ class DistributeTranspiler: outputs={"Out": [orig_param]}, attrs={"axis": 0}) - # TODO(Yancey1989): check dist lookup table if self.has_distributed_lookup_table: self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, eplist) From fc06222ae91990d6eaece2c9895b869742000eae Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 24 May 2018 12:26:50 +0800 Subject: [PATCH 10/17] fix async worker --- paddle/fluid/operators/send_barrier_op.cc | 13 ++++++++----- .../fluid/tests/unittests/test_dist_transpiler.py | 11 +++++++---- .../fluid/transpiler/distribute_transpiler.py | 5 ++++- python/paddle/fluid/transpiler/ps_dispatcher.py | 2 +- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 05e2623630..354eb4fa13 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -37,6 +37,7 @@ class SendBarrierOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& place) const override { std::vector eps = Attr>("endpoints"); + bool sync_mode = Attr("sync_mode"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); @@ -51,12 +52,13 @@ class SendBarrierOp : public framework::OperatorBase { // need to wait before sending send_barrier message PADDLE_ENFORCE(rpc_client->Wait()); - - for (auto& ep : eps) { - VLOG(3) << "send barrier, ep: " << ep; - rpc_client->AsyncSendBatchBarrier(ep); + if (sync_mode) { + for (auto& ep : eps) { + VLOG(3) << "send barrier, ep: " << ep; + rpc_client->AsyncSendBatchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); } - PADDLE_ENFORCE(rpc_client->Wait()); } }; @@ -77,6 +79,7 @@ the Parameter Server would knew all variables have been sent. 
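The switch from a plain int64_t to std::atomic for req_count_ matters because the counter is bumped on the thread issuing asynchronous requests and drained from the completion-queue side, so unsynchronized ++/-- would be a data race. A tiny self-contained illustration of the difference; the thread and iteration counts are arbitrary:

#include <atomic>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  std::atomic<int64_t> req_count{0};
  constexpr int kThreads = 4;
  constexpr int kRequestsPerThread = 10000;

  // Each worker stands in for a callback that increments the counter when an
  // async request is issued; Wait()/Proceed() later decrement it per reply.
  std::vector<std::thread> workers;
  for (int t = 0; t < kThreads; ++t) {
    workers.emplace_back([&] {
      for (int i = 0; i < kRequestsPerThread; ++i) req_count++;
    });
  }
  for (auto& w : workers) w.join();

  // Always prints 40000 with std::atomic; a plain int64_t updated from
  // several threads without a lock could lose increments.
  std::cout << req_count.load() << "\n";
  return 0;
}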
"(string vector, default 127.0.0.1:6164)" "Server endpoints to send variables to.") .SetDefault({"127.0.0.1:6164"}); + AddAttr("sync_mode", "work in sync_mode or not").SetDefault(true); } }; diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 10f8c4f3f0..fa49bd41a5 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase): def test_transpiler(self): trainer = self.get_trainer() pserver, startup = self.get_pserver(self.current_pserver_ep) - self.assertEqual([op.type for op in trainer.global_block().ops], self.get_expect_trainer_ops()) @@ -67,7 +66,7 @@ class TestDistTranspiler(unittest.TestCase): "fill_constant", "fill_constant", "uniform_random", "uniform_random" ]) - # the variable #fc_w will be split into two blocks + # the variable #fc_w will be split into two blocks fc_w_var = startup.global_block().var("fc_w.block1") self.assertEqual(fc_w_var.shape, (500, 1000)) @@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase): optimize_ops, params_grads = self.net_conf() delete_ops(trainer.global_block(), optimize_ops) - return [op.type for op in trainer.global_block().ops - ] + ["split_byref", "send", "concat"] + ops = [op.type for op in trainer.global_block().ops] + [ + "split_byref", "send_vars", "send_barrier", "recv", "recv", + "fetch_barrier", "concat" + ] + ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars") + return ops def get_trainer(self): return self._transpiler_instance().get_trainer_program() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 848cb0bd6c..72a02f24a3 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -348,7 +348,10 @@ class DistributeTranspiler: type="send_barrier", inputs={}, outputs={"RPCClient": rpc_client_var}, - attrs={"endpoints": pserver_endpoints}) + attrs={ + "endpoints": pserver_endpoints, + "sync_mode": self.sync_mode + }) # step 3.2: insert recv op to receive parameters from parameter server recv_vars = [] diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py index dffe66998a..9ba3bf8216 100644 --- a/python/paddle/fluid/transpiler/ps_dispatcher.py +++ b/python/paddle/fluid/transpiler/ps_dispatcher.py @@ -15,7 +15,7 @@ class PSDispatcher(object): """ - DistributedSpliter is the base class for dispatching vars + PSDispatcher is the base class for dispatching vars into different pserver instance. You need to implement the `dispatch` inferface. 
""" From 268e9dc1c6229b28d73e5290c82302eb53e6d1e6 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Sun, 27 May 2018 19:31:16 +0800 Subject: [PATCH 11/17] polish code --- .../details/multi_devices_graph_builder.cc | 22 +++++++++++-------- paddle/fluid/operators/detail/grpc_client.cc | 17 +++++++------- paddle/fluid/operators/detail/grpc_client.h | 3 +-- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 17b9c39201..14b73b3681 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -84,8 +84,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( const ProgramDesc &program) const { std::vector send_vars; + // since parameters are all in block 0, + // it's enough to only scan send ops in block 0 for (auto *op : program.Block(0).AllOps()) { - if (op->Type() == "send_vars" || op->Type() == "send") { + // TODO(Yancey1989): use a graceful method to find send op, + // instead of the the hard code string + if (op->Type() == "send_vars") { auto op_vars = op->InputArgumentNames(); send_vars.reserve(send_vars.size() + std::distance(op_vars.begin(), op_vars.end())); @@ -99,7 +103,9 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( const ProgramDesc &program) const { std::vector recv_vars; for (auto *op : program.Block(0).AllOps()) { - if (op->Type() == "recv" || op->Type() == "send") { + // TODO(Yancey1989): use a graceful method to find recv op, + // instead of the hard code string + if (op->Type() == "recv") { auto op_vars = op->OutputArgumentNames(); recv_vars.reserve(recv_vars.size() + std::distance(op_vars.begin(), op_vars.end())); @@ -122,6 +128,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( auto checker = [](const std::vector &opvars, const std::vector &rpc_vars) -> bool { for (auto &var : opvars) { + // a variable name with the suffix `.block` means it's a splited + // variable by (DistributeTranspiler) + // [python/paddle/fluid/transpiler/distribute_transpiler.py] if (var.find(".block") != std::string::npos && std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { return true; @@ -130,13 +139,8 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( return false; }; - if (op.Type() == "split" || op.Type() == "split_byref" || - op.Type() == "split_selected_rows") { - return checker(op.OutputArgumentNames(), send_vars); - } else if (op.Type() == "concat") { - return checker(op.InputArgumentNames(), recv_vars); - } - + return checker(op.OutputArgumentNames(), send_vars) || + checker(op.InputArgumentNames(), recv_vars); return false; } diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index f2385abed5..51f0d2a742 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -34,7 +34,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, const std::string ep_val = ep; const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val); + const auto ch = GetChannel(ep_val); framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { @@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, const std::string ep_val = ep; const std::string 
var_name_val = var_name; const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val); + const auto ch = GetChannel(ep_val); framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { @@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, const std::string in_var_name_val = in_var_name; const std::string out_var_name_val = out_var_name; const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val); + const auto ch = GetChannel(ep_val); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { @@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, } void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { - const auto ch = GetChannel(ep, ep); + const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); s->Prepare(time_out); @@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { } void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { - const auto ch = GetChannel(ep, ep); + const auto ch = GetChannel(ep); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); s->Prepare(time_out); @@ -248,10 +248,9 @@ bool RPCClient::Proceed() { delete c; return true; } -std::shared_ptr RPCClient::GetChannel(const std::string& ep, - const std::string& key) { +std::shared_ptr RPCClient::GetChannel(const std::string& ep) { std::unique_lock lock(mutex_); - auto it = channels_.find(key); + auto it = channels_.find(ep); if (it != channels_.end()) { return it->second; } @@ -263,7 +262,7 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep, auto ch = grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); - channels_[key] = ch; + channels_[ep] = ch; return ch; } diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 6f8b67be3e..e5007b509a 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -191,8 +191,7 @@ class RPCClient { private: bool Proceed(); - std::shared_ptr GetChannel(const std::string& ep, - const std::string& key); + std::shared_ptr GetChannel(const std::string& ep); private: grpc::CompletionQueue cq_; From ad6c0142c41cd823fbe85dbb944efbae130e8b3d Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 28 May 2018 15:37:29 +0800 Subject: [PATCH 12/17] clean up codes --- paddle/fluid/framework/details/CMakeLists.txt | 3 +- .../details/multi_devices_graph_builder.cc | 11 ++-- .../fluid/framework/details/rpc_op_handle.cc | 1 + .../fluid/framework/details/send_op_handle.cc | 49 ------------------ .../fluid/framework/details/send_op_handle.h | 51 ------------------- paddle/fluid/framework/variable.h | 1 + paddle/fluid/inference/analysis/device.h | 2 - paddle/fluid/operators/detail/grpc_client.cc | 1 + paddle/fluid/operators/recv_op.cc | 6 +-- .../fluid/transpiler/distribute_transpiler.py | 13 ----- .../paddle/fluid/transpiler/ps_dispatcher.py | 2 +- 11 files changed, 12 insertions(+), 128 deletions(-) delete mode 100644 paddle/fluid/framework/details/send_op_handle.cc delete mode 100644 paddle/fluid/framework/details/send_op_handle.h diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 600c47ad5f..1bcd8412eb 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ 
b/paddle/fluid/framework/details/CMakeLists.txt @@ -3,7 +3,6 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) -cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) @@ -27,7 +26,7 @@ endif() cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle send_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) + scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 14b73b3681..25711e0e47 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -19,7 +19,6 @@ #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" -#include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" @@ -141,7 +140,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( return checker(op.OutputArgumentNames(), send_vars) || checker(op.InputArgumentNames(), recv_vars); - return false; } bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const { @@ -471,17 +469,16 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, ConnectOp(result, result->ops_.back().get(), "send_barrier"); } else if (op.Type() == "fetch_barrier") { ConnectOp(result, result->ops_.back().get(), "recv"); - } else if (op.Type() == "send" || op.Type() == "send_vars") { + } else if (op.Type() == "send_vars") { // do nothing } else { PADDLE_THROW( - "rpc op should be in [send," + "rpc op should be in [" "send_vars, send_barrier. recv, fetch_barrier]"); } - // FIXME(wuyi): send op always copy from GPU 0 - // Create inputs for output on original place and no ssa output - // is created for send op. 
+ // TODO(Yancey1989): schedule rpc op on different place may + // increate throughput CreateOpHandleIOs(result, op, 0); } diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index 03f53421b1..7f4da4c01d 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -31,6 +31,7 @@ void RPCOpHandle::RunImpl() { // Wait input done for (auto *in : inputs_) { auto &p = static_cast(in)->place_; + // FIXME(Yancey1989): need a better solution instead of use DebugString() if (in->DebugString() == "dummy") { // HACK continue; } diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc deleted file mode 100644 index 7109659dd7..0000000000 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/details/send_op_handle.h" - -namespace paddle { -namespace framework { -namespace details { - -SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, - const Scope *local_scope, - const platform::Place &place) - : op_(framework::OpRegistry::CreateOp(op_desc)), - local_scope_(local_scope), - place_(place) {} - -void SendOpHandle::RunImpl() { - // TODO(wuyi): need further analysis whether wait VarDummyHandle. - // Wait input done - for (auto *in : inputs_) { - auto &p = static_cast(in)->place_; - if (in->DebugString() == "dummy") { // HACK - continue; - } - if (in->generated_op_) { - in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]); - } - } - auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); - // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead - // lock. - op_->Run(*tmp_scope, place_); -} - -std::string SendOpHandle::Name() const { return "send"; } -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h deleted file mode 100644 index 2f78811fad..0000000000 --- a/paddle/fluid/framework/details/send_op_handle.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include - -#include "paddle/fluid/framework/details/op_handle_base.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/scope.h" - -namespace paddle { -namespace framework { -namespace details { - -struct SendOpHandle : public OpHandleBase { - SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, - const platform::Place& place); - - std::string Name() const override; - - // Delay and buffer nccl_all_reduce together can significantly increase - // performance. Disable this feature by returning false. - bool IsMultiDeviceTransfer() override { return false; }; - - protected: - void RunImpl() override; - - private: - std::unique_ptr op_; - const Scope* local_scope_; - const platform::Place& place_; -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h index 387e06bca6..e7f87ab6f8 100644 --- a/paddle/fluid/framework/variable.h +++ b/paddle/fluid/framework/variable.h @@ -39,6 +39,7 @@ class Variable { template T* GetMutable() { + // TODO(Yancey1989): need to make Variable completely thread-safe. std::unique_lock lock(mutex_); if (!IsType()) { holder_.reset(new PlaceholderImpl(new T())); diff --git a/paddle/fluid/inference/analysis/device.h b/paddle/fluid/inference/analysis/device.h index 9fad445ede..585c992329 100644 --- a/paddle/fluid/inference/analysis/device.h +++ b/paddle/fluid/inference/analysis/device.h @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#pragma once - namespace paddle { namespace inference { namespace analysis { diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 51f0d2a742..4c9c7be40c 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -249,6 +249,7 @@ bool RPCClient::Proceed() { return true; } std::shared_ptr RPCClient::GetChannel(const std::string& ep) { + // TODO(Yancey1989): make grpc client completely thread-safe std::unique_lock lock(mutex_); auto it = channels_.find(ep); if (it != channels_.end()) { diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 1255ed4c49..d416ba1e1f 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -38,7 +38,7 @@ class RecvOp : public framework::OperatorBase { auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); auto client_var_name = Output("RPCClient"); - int sync_recv = Attr("sync_recv"); + int sync_mode = Attr("sync_mode"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); @@ -55,7 +55,7 @@ class RecvOp : public framework::OperatorBase { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - if (sync_recv) { + if (sync_mode) { PADDLE_ENFORCE(rpc_client->Wait()); } } @@ -78,7 +78,7 @@ This operator can get variables from server side. 
"Server endpoints in the order of input " "variables for mapping") .SetDefault({}); - AddAttr("sync_recv", + AddAttr("sync_mode", "(int, default 0)" "sync recv or async recv.") .SetDefault(0); diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 72a02f24a3..a9de5419fa 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -360,19 +360,6 @@ class DistributeTranspiler: ps_dispatcher.reset() eplist = ps_dispatcher.dispatch(recv_vars) - #program.global_block().append_op( - # type="recv", - # inputs={}, - # outputs={"Out": recv_vars, - # "RPCClient": rpc_client_var}, - # attrs={"epmap": eplist}) - - #program.global_block().append_op( - # type="fetch_barrier", - # inputs={}, - # outputs={"RPCClient": rpc_client_var}, - # attrs={"endpoints": pserver_endpoints}) - for i, ep in enumerate(eplist): self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py index 9ba3bf8216..d6a6867752 100644 --- a/python/paddle/fluid/transpiler/ps_dispatcher.py +++ b/python/paddle/fluid/transpiler/ps_dispatcher.py @@ -41,7 +41,7 @@ class PSDispatcher(object): class HashName(PSDispatcher): """ - Hash variable names to servral endpoints + Hash variable names to several endpoints """ def __init__(self, pserver_endpoints): From 28596a3386349c939c6e513e1d9fdc8a78553312 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 28 May 2018 15:50:39 +0800 Subject: [PATCH 13/17] add gflag ssa_graph_path --- .../fluid/framework/details/multi_devices_graph_builder.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 25711e0e47..d21e0f7b96 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -29,6 +29,10 @@ #include #include +DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot", + "the ssa graph path only print with GLOG_v=10," + "default /tmp/graph.dot"); + namespace paddle { namespace framework { namespace details { @@ -264,7 +268,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( AddOutputToLeafOps(&result); if (VLOG_IS_ON(10)) { - std::ofstream fout("/tmp/graph.dot"); + std::ofstream fout(FLAGS_ssa_graph_path); PrintGraphviz(*graph, fout); } From 20c24c05aa79f2c07f4fcc509feff816ec133a04 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 29 May 2018 11:37:05 +0800 Subject: [PATCH 14/17] singleton rpc_client --- .../details/multi_devices_graph_builder.cc | 13 +- .../details/multi_devices_graph_builder.h | 2 - paddle/fluid/framework/op_proto_maker.cc | 2 +- paddle/fluid/framework/op_proto_maker.h | 1 + .../analysis/data_flow_graph_tester.cc | 4 +- paddle/fluid/operators/detail/grpc_client.cc | 15 ++ paddle/fluid/operators/detail/grpc_client.h | 10 ++ .../operators/detail/grpc_server_test.cc | 11 +- paddle/fluid/operators/fetch_barrier_op.cc | 22 +-- paddle/fluid/operators/prefetch_op.cc | 22 +-- paddle/fluid/operators/recv_op.cc | 10 +- paddle/fluid/operators/send_barrier_op.cc | 23 +-- paddle/fluid/operators/send_op.cc | 24 +--- paddle/fluid/operators/send_recv_op_test.cc | 136 +++++++++--------- paddle/fluid/operators/send_vars_op.cc | 22 +-- 
paddle/fluid/pybind/const_value.cc | 3 +- python/paddle/fluid/layers/io.py | 14 +- .../fluid/transpiler/distribute_transpiler.py | 67 +++++---- 18 files changed, 161 insertions(+), 240 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index d21e0f7b96..d8e711994c 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -146,15 +146,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( checker(op.InputArgumentNames(), recv_vars); } -bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const { - for (auto &name : op.OutputNames()) { - if (name == "RPCClient") { - return true; - } - } - return false; -} - std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { std::unordered_map var_types; @@ -184,7 +175,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( bool is_forwarding = true; for (auto *op : program.Block(0).AllOps()) { - if (IsRPCOp(*op)) { + if (boost::get( + op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + static_cast(OpRole::kRPC)) { // append rpc op if program is distributed trainer main program. // always use the first device CreateRPCOp(&result, *op); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index be17c2a92e..e07597dbd8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -80,8 +80,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { std::vector FindDistTrainRecvVars( const ProgramDesc &program) const; - bool IsRPCOp(const OpDesc &op) const; - void ConnectOp(SSAGraph *result, OpHandleBase *op, const std::string &prev_op_name) const; diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 5a4380a83a..ae9f4efd44 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, .InEnum( {static_cast(OpRole::kForward), static_cast(OpRole::kBackward), - static_cast(OpRole::kOptimize), + static_cast(OpRole::kOptimize), static_cast(OpRole::kRPC), static_cast(OpRole::kLoss) | static_cast(OpRole::kForward), static_cast(OpRole::kLoss) | static_cast(OpRole::kBackward), diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 9bd6ca6ea3..8493b9d8b3 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -24,6 +24,7 @@ enum class OpRole { kForward = 0x0000, kBackward = 0x0001, kOptimize = 0x0002, + kRPC = 0x0003, kLoss = 0x0100, // The default value of op's role. 
This should be only used for unittests and diff --git a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc index 51d38d6251..9d7cceeb65 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc @@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) { GraphTraits trait(&dfg); auto nodes = trait.nodes(); - int count = 0; + size_t count = 0; for (auto it = nodes.begin(); it != nodes.end(); ++it) { LOG(INFO) << "visiting " << it->name(); ++count; @@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) { dfg.Build(); GraphTraits trait(&dfg); auto nodes = trait.nodes_in_DFS(); - int count = 0; + size_t count = 0; for (auto it = nodes.begin(); it != nodes.end(); ++it) { LOG(INFO) << "visiting " << it->name(); ++count; diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 4c9c7be40c..f7ce778687 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -25,6 +25,21 @@ namespace paddle { namespace operators { namespace detail { +std::once_flag RPCClient::init_flag_; + +std::unique_ptr RPCClient::rpc_client_(nullptr); + +RPCClient* RPCClient::GetInstance() { + std::call_once(init_flag_, &RPCClient::Init); + return rpc_client_.get(); +} + +void RPCClient::Init() { + if (rpc_client_.get() == nullptr) { + rpc_client_.reset(new RPCClient()); + } +} + bool RPCClient::AsyncSendVariable(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index e5007b509a..449d5105af 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -36,6 +36,7 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN namespace paddle { namespace operators { @@ -162,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor { class RPCClient { public: + RPCClient() {} + + static RPCClient* GetInstance(); + bool AsyncSendVariable(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, @@ -192,12 +197,17 @@ class RPCClient { private: bool Proceed(); std::shared_ptr GetChannel(const std::string& ep); + // Init is called by GetInstance. 
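GetInstance/Init follow the standard lazily initialized singleton: std::call_once guarantees the client is constructed exactly once even when many operators ask for it concurrently. A stripped-down sketch of the pattern with a placeholder class (not the real grpc-backed RPCClient, and with the constructor made private for the sketch):

#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

class FakeClient {
 public:
  static FakeClient* GetInstance() {
    std::call_once(init_flag_, &FakeClient::Init);
    return instance_.get();
  }

 private:
  FakeClient() { std::cout << "constructed once\n"; }
  static void Init() { instance_.reset(new FakeClient()); }

  static std::once_flag init_flag_;
  static std::unique_ptr<FakeClient> instance_;
};

std::once_flag FakeClient::init_flag_;
std::unique_ptr<FakeClient> FakeClient::instance_;

int main() {
  // Many threads (standing in for send/recv/prefetch operators) all get the
  // same instance; the constructor message appears exactly once.
  std::vector<std::thread> threads;
  for (int i = 0; i < 8; ++i) {
    threads.emplace_back([] { FakeClient::GetInstance(); });
  }
  for (auto& t : threads) t.join();
  return 0;
}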
+ static void Init(); private: grpc::CompletionQueue cq_; std::map> channels_; std::atomic req_count_{0}; std::mutex mutex_; + static std::unique_ptr rpc_client_; + static std::once_flag init_flag_; + DISABLE_COPY_AND_ASSIGN(RPCClient); }; } // namespace detail diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 73e75c9087..264e3c6671 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -121,10 +121,13 @@ TEST(PREFETCH, DISABLED_CPU) { std::string in_var_name("ids"); std::string out_var_name("out"); - detail::RPCClient client; - client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, - out_var_name); - client.Wait(); + detail::RPCClient::GetInstance(); + + // detail::RPCClient::GetInstance(); + // client->Wait(); + // client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, + // out_var_name); + // client->Wait(); auto var = scope.Var(out_var_name); auto value = var->GetMutable()->value(); diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index 5d2e558699..79ec02f520 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -43,12 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto client_var_name = Output("RPCClient"); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + auto rpc_client = detail::RPCClient::GetInstance(); PADDLE_ENFORCE(rpc_client->Wait()); @@ -63,9 +58,6 @@ class FetchBarrierOp : public framework::OperatorBase { class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { - AddOutput("RPCClient", - "(RPCClient) The RPC client object which is" - "initialized at most once."); AddComment(R"DOC( SendBarrier operator @@ -80,17 +72,6 @@ the Parameter Server would knew all variables have been sent. 
} }; -class FetchBarrierOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output("RPCClient").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - class FetchBarrierOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override {} @@ -103,5 +84,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp, paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker, - ops::FetchBarrierOpVarTypeInference, ops::FetchBarrierOpShapeInference); diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index 4cfea958e8..e0a9b24ac8 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); - auto client_var_name = Output("RPCClient"); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + auto rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { @@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable(); - AddOutput("RPCClient", - "(RPCClient) The RPC client object which will be" - "initialized at most once."); AddOutput("Out", "(LoDTensor) result " "to be fetched from parameter server") @@ -87,17 +79,6 @@ the parameter server and fetch result back. 
} }; -class PrefetchOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output("RPCClient").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - class PrefetchOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override {} @@ -110,5 +91,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(prefetch, ops::PrefetchOp, paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker, - ops::PrefetchOpVarTypeInference, ops::PrefetchOpShapeInference); diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index d416ba1e1f..d8ddb7b448 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -37,7 +37,6 @@ class RecvOp : public framework::OperatorBase { const platform::Place& place) const override { auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); - auto client_var_name = Output("RPCClient"); int sync_mode = Attr("sync_mode"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); @@ -45,11 +44,7 @@ class RecvOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + auto rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; @@ -65,9 +60,6 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable(); - AddOutput("RPCClient", - "(RPCClient) The RPC client object which is" - "initialized at most once."); AddComment(R"DOC( Recv operator diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 354eb4fa13..2c77ee2e27 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -43,12 +43,8 @@ class SendBarrierOp : public framework::OperatorBase { auto& ctx = *pool.Get(place); // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto client_var_name = Output("RPCClient"); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + + auto rpc_client = detail::RPCClient::GetInstance(); // need to wait before sending send_barrier message PADDLE_ENFORCE(rpc_client->Wait()); @@ -65,9 +61,6 @@ class SendBarrierOp : public framework::OperatorBase { class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { - AddOutput("RPCClient", - "(RPCClient) The RPC client object which is" - "initialized at most once."); AddComment(R"DOC( SendBarrier operator @@ -83,17 +76,6 @@ the Parameter Server would knew all variables have been sent. 
} }; -class SendBarrierOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output("RPCClient").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - class SendBarrierOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override {} @@ -106,5 +88,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp, paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker, - ops::SendBarrierOpVarTypeInference, ops::SendBarrierOpShapeInference); diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 95bb1f3c69..a5150f242c 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto client_var_name = Output("RPCClient"); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + auto rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { @@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); AddOutput("Out", "(Tensor) Output tensor to be received from server") .AsDuplicable(); - AddOutput("RPCClient", - "(RPCClient) The RPC client object which is" - "initialized at most once."); AddComment(R"DOC( Send operator @@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server. 
} }; -class SendOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output("RPCClient").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - class SendOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override {} @@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase { namespace ops = paddle::operators; REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker, - ops::SendOpMaker, ops::SendOpVarTypeInference, - ops::SendOpShapeInference); + ops::SendOpMaker, ops::SendOpShapeInference); diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc index d5303eaf50..2b3dc81676 100644 --- a/paddle/fluid/operators/send_recv_op_test.cc +++ b/paddle/fluid/operators/send_recv_op_test.cc @@ -177,75 +177,75 @@ TEST(SendRecvOp, CPUDense) { attrs.insert({"epmap", std::vector({endpoint})}); auto send_op = f::OpRegistry::CreateOp( "send", {{"X", {"x1"}}}, - {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); - send_op->Run(scope, place); - - auto in_var = scope.Var("x1"); - auto tensor = in_var->GetMutable(); - float *expected = tensor->data(); - auto out_var = scope.Var("Out"); - auto target = out_var->GetMutable(); - // x1 * 2 == x0 - EXPECT_NE(target->memory_size(), size_t(0)); - float *actual = target->data(); - for (int64_t i = 0; i < target->numel(); ++i) { - EXPECT_EQ(expected[i] * 2, actual[i]); - } - listen_and_serv_op->Stop(); - server_thread.join(); - listen_and_serv_op.reset(nullptr); - paddle::operators::ListenAndServOp::ResetPort(); + {{"Out", {"Out"}}, attrs); + send_op->Run(scope, place); + + auto in_var = scope.Var("x1"); + auto tensor = in_var->GetMutable(); + float *expected = tensor->data(); + auto out_var = scope.Var("Out"); + auto target = out_var->GetMutable(); + // x1 * 2 == x0 + EXPECT_NE(target->memory_size(), size_t(0)); + float *actual = target->data(); + for (int64_t i = 0; i < target->numel(); ++i) { + EXPECT_EQ(expected[i] * 2, actual[i]); + } + listen_and_serv_op->Stop(); + server_thread.join(); + listen_and_serv_op.reset(nullptr); + paddle::operators::ListenAndServOp::ResetPort(); } TEST(SendRecvOp, CPUSparse) { - std::atomic initialized; - initialized = false; - std::thread server_thread(StartServerNet, true, &initialized); - while (!initialized) { - } - auto *listen_and_serv_op_ptr = - static_cast( - listen_and_serv_op.get()); - ASSERT_TRUE(listen_and_serv_op_ptr != nullptr); - listen_and_serv_op_ptr->WaitServerReady(); - - // local net - f::Scope scope; - p::CPUPlace place; - p::CPUDeviceContext ctx(place); - InitSelectedRowsInScope(place, &scope); - scope.Var("RPC_CLIENT_VAR"); - f::AttributeMap attrs; - selected_port = listen_and_serv_op_ptr->GetSelectedPort(); - std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); - attrs.insert({"endpoints", std::vector({endpoint})}); - attrs.insert({"epmap", std::vector({endpoint})}); - auto send_op = f::OpRegistry::CreateOp( - "send", {{"X", {"x1"}}}, - {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); - send_op->Run(scope, place); - - auto x0 = scope.Var("x0")->GetMutable(); - auto x1 = scope.Var("x1")->GetMutable(); - auto out = scope.Var("Out")->GetMutable(); - auto actual = 
out->mutable_value(); - - std::unique_ptr expect{new f::SelectedRows()}; - auto expect_value = expect->mutable_value(); - expect_value->mutable_data(f::make_ddim({5, 10}), place); - - m::SelectedRowsAdd add_functor; - add_functor(ctx, *x0, *x1, expect.get()); - - EXPECT_EQ(actual->numel(), expect_value->numel()); - EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size()); - - for (int64_t i = 0; i < expect_value->numel(); ++i) { - EXPECT_EQ(expect_value->mutable_data(place)[i], - actual->mutable_data(place)[i]); - } - listen_and_serv_op->Stop(); - server_thread.join(); - listen_and_serv_op.reset(); - paddle::operators::ListenAndServOp::ResetPort(); + std::atomic initialized; + initialized = false; + std::thread server_thread(StartServerNet, true, &initialized); + while (!initialized) { + } + auto *listen_and_serv_op_ptr = + static_cast( + listen_and_serv_op.get()); + ASSERT_TRUE(listen_and_serv_op_ptr != nullptr); + listen_and_serv_op_ptr->WaitServerReady(); + + // local net + f::Scope scope; + p::CPUPlace place; + p::CPUDeviceContext ctx(place); + InitSelectedRowsInScope(place, &scope); + scope.Var("RPC_CLIENT_VAR"); + f::AttributeMap attrs; + selected_port = listen_and_serv_op_ptr->GetSelectedPort(); + std::string endpoint = + paddle::string::Sprintf("127.0.0.1:%d", selected_port); + attrs.insert({"endpoints", std::vector({endpoint})}); + attrs.insert({"epmap", std::vector({endpoint})}); + auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, + {{"Out", {"Out"}}}, attrs); + send_op->Run(scope, place); + + auto x0 = scope.Var("x0")->GetMutable(); + auto x1 = scope.Var("x1")->GetMutable(); + auto out = scope.Var("Out")->GetMutable(); + auto actual = out->mutable_value(); + + std::unique_ptr expect{new f::SelectedRows()}; + auto expect_value = expect->mutable_value(); + expect_value->mutable_data(f::make_ddim({5, 10}), place); + + m::SelectedRowsAdd add_functor; + add_functor(ctx, *x0, *x1, expect.get()); + + EXPECT_EQ(actual->numel(), expect_value->numel()); + EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size()); + + for (int64_t i = 0; i < expect_value->numel(); ++i) { + EXPECT_EQ(expect_value->mutable_data(place)[i], + actual->mutable_data(place)[i]); + } + listen_and_serv_op->Stop(); + server_thread.join(); + listen_and_serv_op.reset(); + paddle::operators::ListenAndServOp::ResetPort(); } diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc index 8d5b5f4292..fe839dab69 100644 --- a/paddle/fluid/operators/send_vars_op.cc +++ b/paddle/fluid/operators/send_vars_op.cc @@ -45,12 +45,7 @@ class SendVarsOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto client_var_name = Output("RPCClient"); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + auto rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { @@ -73,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker { void Make() { AddInput("X", "(Tensor, SelectedRows) Input variables to be sent") .AsDuplicable(); - AddOutput("RPCClient", - "(RPCClient) The RPC client object which will be" - "initialized at most once."); AddComment(R"DOC( Send operator @@ -93,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter 
server.
)DOC");
  }
};
-class SendVarsOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output("RPCClient").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
-  }
-};
-
 class SendVarsOpShapeInference : public framework::InferShapeBase {
  public:
  void operator()(framework::InferShapeContext* ctx) const override {}
@@ -116,5 +97,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
                   paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
-                  ops::SendVarsOpVarTypeInference,
                   ops::SendVarsOpShapeInference);
diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc
index 9111abca5a..76aa7d2010 100644
--- a/paddle/fluid/pybind/const_value.cc
+++ b/paddle/fluid/pybind/const_value.cc
@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
       .value("Forward", framework::OpRole::kForward)
       .value("Backward", framework::OpRole::kBackward)
       .value("Optimize", framework::OpRole::kOptimize)
-      .value("Loss", framework::OpRole::kLoss);
+      .value("Loss", framework::OpRole::kLoss)
+      .value("RPC", framework::OpRole::kRPC);
   op_proto_and_checker_maker.def(
       "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index 03d4602f7a..8758ac9f94 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None):
     endpoints = list(set(epmap))
     helper = LayerHelper("Send", **locals())
-    rpc_client_var = default_main_program().global_block().create_var(
-        name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW)
     if not get_vars:
         get_vars = []
         for s in send_vars:
             v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
             get_vars.append(v)
+    rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
     helper.append_op(
         type="send",
         inputs={"X": send_vars},
-        outputs={"Out": get_vars,
-                 "RPCClient": rpc_client_var},
-        attrs={"endpoints": endpoints,
-               "epmap": epmap})
+        outputs={"Out": get_vars},
+        attrs={
+            "endpoints": endpoints,
+            "epmap": epmap,
+            rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
+        })
+
     return get_vars
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index a9de5419fa..4e17fdb16b 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
 LOOKUP_TABLE_TYPE = "lookup_table"
 LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
-RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
+RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
+)
+RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
 class VarBlock:
@@ -297,11 +299,6 @@ class DistributeTranspiler:
         grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
                 param_var_mapping[p_name][int(p_bid)]
-        rpc_client_var = program.global_block().create_var(
-            name=RPC_CLIENT_VAR_NAME,
-            persistable=True,
-            type=core.VarDesc.VarType.RAW)
-
         # step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program @@ -338,8 +335,11 @@ class DistributeTranspiler: index=index + 1, type="send_vars", inputs={"X": splited_vars}, - outputs={"RPCClient": rpc_client_var}, - attrs={"epmap": eplist}) + outputs={}, + attrs={ + "epmap": eplist, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) for _, var in enumerate(splited_vars): send_vars.append(var) @@ -347,10 +347,11 @@ class DistributeTranspiler: program.global_block().append_op( type="send_barrier", inputs={}, - outputs={"RPCClient": rpc_client_var}, + outputs={}, attrs={ "endpoints": pserver_endpoints, - "sync_mode": self.sync_mode + "sync_mode": self.sync_mode, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) # step 3.2: insert recv op to receive parameters from parameter server @@ -373,15 +374,20 @@ class DistributeTranspiler: program.global_block().append_op( type="recv", inputs={}, - outputs={"Out": splited_var, - "RPCClient": rpc_client_var}, - attrs={"epmap": eps}) + outputs={"Out": splited_var}, + attrs={ + "epmap": eps, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) program.global_block().append_op( type="fetch_barrier", inputs={}, - outputs={"RPCClient": rpc_client_var}, - attrs={"endpoints": pserver_endpoints}) + outputs={}, + attrs={ + "endpoints": pserver_endpoints, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) for varname, splited_var in param_var_mapping.iteritems(): if len(splited_var) <= 1: @@ -394,10 +400,8 @@ class DistributeTranspiler: attrs={"axis": 0}) if self.has_distributed_lookup_table: - self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, - eplist) - self._split_table_grad_and_add_send_vars(program, rpc_client_var, - pserver_endpoints) + self._replace_lookup_table_op_with_prefetch(program, eplist) + self._split_table_grad_and_add_send_vars(program, pserver_endpoints) def get_trainer_program(self): # remove optimize ops and add a send op to main_program @@ -617,8 +621,7 @@ class DistributeTranspiler: return s_prog # transpiler function for dis lookup_table - def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var, - eplist): + def _replace_lookup_table_op_with_prefetch(self, program, eplist): # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op self.prefetch_input_vars = None self.prefetch_output_vars = None @@ -665,11 +668,11 @@ class DistributeTranspiler: index=op_index + 1, type="prefetch", inputs={'X': self.prefetch_input_vars}, - outputs={ - "Out": self.prefetch_output_vars, - "RPCClient": rpc_client_var - }, - attrs={"epmap": eplist}) + outputs={"Out": self.prefetch_output_vars}, + attrs={ + "epmap": eplist, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) # insert concat_op program.global_block().insert_op( @@ -689,8 +692,7 @@ class DistributeTranspiler: # break for loop break - def _split_table_grad_and_add_send_vars(self, program, rpc_client_var, - pserver_endpoints): + def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints): # 2. 
add split_ids_op and send_vars_op to send gradient to pservers
        # there should only be one table_name
        all_ops = program.global_block().ops
@@ -710,9 +712,12 @@ class DistributeTranspiler:
                 index=op_index + 2,
                 type="send_vars",
                 inputs={'X': self.table_grad_list},
-                outputs={"RPCClient": rpc_client_var},
-                attrs={"sync_send": True,
-                       "epmap": pserver_endpoints})
+                outputs={},
+                attrs={
+                    "sync_send": True,
+                    "epmap": pserver_endpoints,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
                 break
 
     def _create_prefetch_block(self, pserver_index, pserver_program,

From 6b91d407dea061503f0d5fec5018d2ffbc551793 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Tue, 29 May 2018 11:43:04 +0800
Subject: [PATCH 15/17] revert variable mutex

---
 paddle/fluid/framework/variable.h | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h
index e7f87ab6f8..067e0c2b83 100644
--- a/paddle/fluid/framework/variable.h
+++ b/paddle/fluid/framework/variable.h
@@ -14,7 +14,6 @@
 #pragma once
 #include <memory>
-#include <mutex>  // NOLINT
 #include <string>
 #include <typeindex>
@@ -39,8 +38,6 @@ class Variable {
   template <typename T>
   T* GetMutable() {
-    // TODO(Yancey1989): need to make Variable completely thread-safe.
-    std::unique_lock<std::mutex> lock(mutex_);
     if (!IsType<T>()) {
       holder_.reset(new PlaceholderImpl<T>(new T()));
     }
@@ -93,7 +90,6 @@ class Variable {
   // by its address but not the unreadable name.
   friend class Scope;
   const std::string* name_;
-  std::mutex mutex_;
 };
 }  // namespace framework

From 8b630ae1b53d24362c5a0e4061f56a75c12eca0b Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Tue, 29 May 2018 11:52:39 +0800
Subject: [PATCH 16/17] fix unit test

---
 paddle/fluid/operators/detail/grpc_server_test.cc | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc
index 264e3c6671..350a7ee123 100644
--- a/paddle/fluid/operators/detail/grpc_server_test.cc
+++ b/paddle/fluid/operators/detail/grpc_server_test.cc
@@ -121,13 +121,10 @@ TEST(PREFETCH, DISABLED_CPU) {
   std::string in_var_name("ids");
   std::string out_var_name("out");
-  detail::RPCClient::GetInstance();
-
-  // detail::RPCClient::GetInstance();
-  // client->Wait();
-  // client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
-  //                               out_var_name);
-  // client->Wait();
+  auto client = detail::RPCClient::GetInstance();
+  client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
+                                out_var_name);
+  client->Wait();
   auto var = scope.Var(out_var_name);
   auto value = var->GetMutable<framework::SelectedRows>()->value();

From 5d7c58e46935f74903d131a56fcc43a713a66753 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Tue, 29 May 2018 12:28:22 +0800
Subject: [PATCH 17/17] fix code style

---
 paddle/fluid/operators/send_recv_op_test.cc | 141 ++++++++++----------
 1 file changed, 71 insertions(+), 70 deletions(-)

diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc
index 2b3dc81676..e550552b19 100644
--- a/paddle/fluid/operators/send_recv_op_test.cc
+++ b/paddle/fluid/operators/send_recv_op_test.cc
@@ -156,6 +156,7 @@ TEST(SendRecvOp, CPUDense) {
   std::thread server_thread(StartServerNet, false, &initialized);
   while (!initialized) {
   }
+
   static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
       ->WaitServerReady();
@@ -175,77 +176,77 @@ TEST(SendRecvOp, CPUDense) {
   std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
   attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
   attrs.insert({"epmap", std::vector<std::string>({endpoint})});
-  auto send_op = f::OpRegistry::CreateOp(
-      "send", {{"X", {"x1"}}},
-      {{"Out", {"Out"}}}, attrs);
-  send_op->Run(scope, place);
-
-  auto in_var = scope.Var("x1");
-  auto tensor = in_var->GetMutable<f::LoDTensor>();
-  float *expected = tensor->data<float>();
-  auto out_var = scope.Var("Out");
-  auto target = out_var->GetMutable<f::LoDTensor>();
-  // x1 * 2 == x0
-  EXPECT_NE(target->memory_size(), size_t(0));
-  float *actual = target->data<float>();
-  for (int64_t i = 0; i < target->numel(); ++i) {
-    EXPECT_EQ(expected[i] * 2, actual[i]);
-  }
-  listen_and_serv_op->Stop();
-  server_thread.join();
-  listen_and_serv_op.reset(nullptr);
-  paddle::operators::ListenAndServOp::ResetPort();
+  const f::VariableNameMap &inputs = {{"X", {"x1"}}};
+  const f::VariableNameMap &outputs = {{"Out", {"Out"}}};
+
+  auto send_op = f::OpRegistry::CreateOp("send", inputs, outputs, attrs);
+  send_op->Run(scope, place);
+
+  auto in_var = scope.Var("x1");
+  auto tensor = in_var->GetMutable<f::LoDTensor>();
+  float *expected = tensor->data<float>();
+  auto out_var = scope.Var("Out");
+  auto target = out_var->GetMutable<f::LoDTensor>();
+  // x1 * 2 == x0
+  EXPECT_NE(target->memory_size(), size_t(0));
+  float *actual = target->data<float>();
+  for (int64_t i = 0; i < target->numel(); ++i) {
+    EXPECT_EQ(expected[i] * 2, actual[i]);
+  }
+  listen_and_serv_op->Stop();
+  server_thread.join();
+  listen_and_serv_op.reset(nullptr);
+  paddle::operators::ListenAndServOp::ResetPort();
 }

 TEST(SendRecvOp, CPUSparse) {
-  std::atomic<bool> initialized;
-  initialized = false;
-  std::thread server_thread(StartServerNet, true, &initialized);
-  while (!initialized) {
-  }
-  auto *listen_and_serv_op_ptr =
-      static_cast<paddle::operators::ListenAndServOp *>(
-          listen_and_serv_op.get());
-  ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
-  listen_and_serv_op_ptr->WaitServerReady();
-
-  // local net
-  f::Scope scope;
-  p::CPUPlace place;
-  p::CPUDeviceContext ctx(place);
-  InitSelectedRowsInScope(place, &scope);
-  scope.Var("RPC_CLIENT_VAR");
-  f::AttributeMap attrs;
-  selected_port = listen_and_serv_op_ptr->GetSelectedPort();
-  std::string endpoint =
-      paddle::string::Sprintf("127.0.0.1:%d", selected_port);
-  attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
-  attrs.insert({"epmap", std::vector<std::string>({endpoint})});
-  auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
-                                         {{"Out", {"Out"}}}, attrs);
-  send_op->Run(scope, place);
-
-  auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
-  auto x1 = scope.Var("x1")->GetMutable<f::SelectedRows>();
-  auto out = scope.Var("Out")->GetMutable<f::SelectedRows>();
-  auto actual = out->mutable_value();
-
-  std::unique_ptr<f::SelectedRows> expect{new f::SelectedRows()};
-  auto expect_value = expect->mutable_value();
-  expect_value->mutable_data<float>(f::make_ddim({5, 10}), place);
-
-  m::SelectedRowsAdd<p::CPUDeviceContext, float> add_functor;
-  add_functor(ctx, *x0, *x1, expect.get());
-
-  EXPECT_EQ(actual->numel(), expect_value->numel());
-  EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size());
-
-  for (int64_t i = 0; i < expect_value->numel(); ++i) {
-    EXPECT_EQ(expect_value->mutable_data<float>(place)[i],
-              actual->mutable_data<float>(place)[i]);
-  }
-  listen_and_serv_op->Stop();
-  server_thread.join();
-  listen_and_serv_op.reset();
-  paddle::operators::ListenAndServOp::ResetPort();
+  std::atomic<bool> initialized;
+  initialized = false;
+  std::thread server_thread(StartServerNet, true, &initialized);
+  while (!initialized) {
+  }
+  auto *listen_and_serv_op_ptr =
+      static_cast<paddle::operators::ListenAndServOp *>(
+          listen_and_serv_op.get());
+  ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
+  listen_and_serv_op_ptr->WaitServerReady();
+
+  // local net
+  f::Scope scope;
+  p::CPUPlace place;
+  p::CPUDeviceContext ctx(place);
+  InitSelectedRowsInScope(place, &scope);
+  scope.Var("RPC_CLIENT_VAR");
+  f::AttributeMap attrs;
+  selected_port = listen_and_serv_op_ptr->GetSelectedPort();
+  std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
+  attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
+  attrs.insert({"epmap", std::vector<std::string>({endpoint})});
+  auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
+                                         {{"Out", {"Out"}}}, attrs);
+  send_op->Run(scope, place);
+
+  auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
+  auto x1 = scope.Var("x1")->GetMutable<f::SelectedRows>();
+  auto out = scope.Var("Out")->GetMutable<f::SelectedRows>();
+  auto actual = out->mutable_value();
+
+  std::unique_ptr<f::SelectedRows> expect{new f::SelectedRows()};
+  auto expect_value = expect->mutable_value();
+  expect_value->mutable_data<float>(f::make_ddim({5, 10}), place);
+
+  m::SelectedRowsAdd<p::CPUDeviceContext, float> add_functor;
+  add_functor(ctx, *x0, *x1, expect.get());
+
+  EXPECT_EQ(actual->numel(), expect_value->numel());
+  EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size());
+
+  for (int64_t i = 0; i < expect_value->numel(); ++i) {
+    EXPECT_EQ(expect_value->mutable_data<float>(place)[i],
+              actual->mutable_data<float>(place)[i]);
+  }
+  listen_and_serv_op->Stop();
+  server_thread.join();
+  listen_and_serv_op.reset();
+  paddle::operators::ListenAndServOp::ResetPort();
 }
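Note on the RPC client change threaded through this series: the send, recv, and prefetch operators now obtain one process-wide client via detail::RPCClient::GetInstance() instead of reading an "RPCClient" variable out of the scope, which is why the RPCClient outputs and SendVarsOpVarTypeInference are deleted above. As a reference point only, here is a minimal, hypothetical sketch of such an accessor (a C++11 function-local-static singleton; the class name and members are assumed and this is not the actual PaddlePaddle implementation):

    #include <cstdio>

    // Hypothetical sketch of a process-wide client accessor.
    // The function-local static is constructed exactly once, on first use,
    // and C++11 guarantees that this initialization is thread-safe.
    class RPCClient {
     public:
      static RPCClient* GetInstance() {
        static RPCClient instance;
        return &instance;
      }
      // Non-copyable: every caller shares the same client object.
      RPCClient(const RPCClient&) = delete;
      RPCClient& operator=(const RPCClient&) = delete;

     private:
      RPCClient() { std::printf("client constructed once\n"); }
    };

    int main() {
      // Both calls return the same instance; the constructor runs only once.
      return RPCClient::GetInstance() == RPCClient::GetInstance() ? 0 : 1;
    }

With this pattern every operator that talks to the parameter server shares a single client, and the graph no longer has to carry a RAW "RPCClient" variable just to pass the client around.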