test=develop, optimize geo communicator (#26857)

* test=develop, optimize geo communicator
my_2.0rc
123malin 4 years ago committed by GitHub
parent 5132f5129d
commit cc780b1977
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large Load Diff

@ -284,7 +284,7 @@ class AsyncCommunicator : public Communicator {
void InitParams(); void InitParams();
void MainThread(); virtual void MainThread();
void Send(const std::vector<std::string> &var_names, void Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables, const std::vector<std::string> &var_tables,
@ -408,7 +408,7 @@ class GeoCommunicator : public AsyncCommunicator {
void InitImpl(const RpcCtxMap &send_varname_to_ctx, void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx, const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override; Scope *recv_scope) override;
void MainThread() override;
void InitEnvs() { void InitEnvs() {
min_send_grad_num_before_recv_ = 0; min_send_grad_num_before_recv_ = 0;
@ -426,9 +426,12 @@ class GeoCommunicator : public AsyncCommunicator {
const std::vector<std::string> &var_tables, const std::vector<std::string> &var_tables,
const framework::Scope &scope) override; const framework::Scope &scope) override;
void SendByCommunicator(int batches) override; void SendByCommunicator(int batches) { return; }
std::vector<int64_t> MergeSparseIds(const std::string &send_varname);
void SendSparse(const std::string &varname, int batches); void SendSparse(const std::string &varname, int ep_idx,
const std::vector<int64_t> &sparse_ids);
void SendDense(const std::string &varname); void SendDense(const std::string &varname);
@ -436,7 +439,7 @@ class GeoCommunicator : public AsyncCommunicator {
void RecvByCommunicator() override; void RecvByCommunicator() override;
void RecvSparse(const std::string &varname); void RecvSparse(const std::string &varname, int ep_idx);
void RecvDense(const std::string &varname); void RecvDense(const std::string &varname);
@ -459,11 +462,13 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver // parameter on pserver
std::shared_ptr<Scope> pserver_scope_; std::shared_ptr<Scope> pserver_scope_;
std::unordered_map<std::string, int send_var_nums_ = 0;
std::shared_ptr<BlockingQueue<std::vector<int64_t>>>>
send_ids_to_queue_;
std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_; std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_;
std::unordered_map<
std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>>
sparse_id_queues_;
}; };
} // namespace distributed } // namespace distributed

@ -144,7 +144,8 @@ CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
void CPUDeviceContext::InitPoolDevice() { void CPUDeviceContext::InitPoolDevice() {
using EigenEnv = Eigen::StlThreadEnvironment; using EigenEnv = Eigen::StlThreadEnvironment;
using EigenThreadPool = Eigen::ThreadPoolTempl<EigenEnv>; using EigenThreadPool = Eigen::ThreadPoolTempl<EigenEnv>;
int num_threads = std::thread::hardware_concurrency(); // int num_threads = std::thread::hardware_concurrency();
int num_threads = 1;
eigen_threadpool_.reset(new EigenThreadPool(num_threads)); eigen_threadpool_.reset(new EigenThreadPool(num_threads));
eigen_pool_device_.reset( eigen_pool_device_.reset(
new Eigen::ThreadPoolDevice(eigen_threadpool_.get(), num_threads)); new Eigen::ThreadPoolDevice(eigen_threadpool_.get(), num_threads));

@ -169,7 +169,7 @@ def append_send_ops_pass(program, config):
trainer_id = config.get_role_id() trainer_id = config.get_role_id()
pserver_endpoints = config.get_ps_endpoints() pserver_endpoints = config.get_ps_endpoints()
def _append_send_op(union_vars, queue): def _append_grad_send_op(union_vars, queue):
if queue == STEP_COUNTER: if queue == STEP_COUNTER:
send_input_vars = [] send_input_vars = []
@ -198,6 +198,43 @@ def append_send_ops_pass(program, config):
return dummy_output return dummy_output
# Appends one "send" op to the program's global block whose inputs are the
# "Ids" variables feeding sparse ops, and returns the op's output placeholder.
# Reads `program`, `mode` and `pserver_endpoints` from the enclosing scope.
# NOTE(review): leading indentation is missing in this copy of the code, so the
# nesting described in the comments below is inferred from statement order --
# confirm against the real file before relying on it.
def _append_sparse_ids_send_op():
sparse_var = []
sparse_tables = []
unique_sparse_var = {}
# Walk every op in the global block; ops whose attribute set contains
# "is_sparse" are presumably the only ones processed further (nesting inferred).
# "lookup_table" ops additionally get their 'remote_prefetch' attr set to False.
for op in program.global_block().ops:
if "is_sparse" in op.all_attrs():
if op.type == "lookup_table":
op._set_attr('remote_prefetch', False)
# Pair each "Ids" input name with the "W" input name at the same position.
for input_var_name, sparse_var_name in zip(
op.input("Ids"), op.input("W")):
# Skip a pair only when this ids name was already recorded with the same
# table name; an ids name seen with a different table is appended again and
# its entry in unique_sparse_var is overwritten.
if input_var_name in unique_sparse_var:
if unique_sparse_var[input_var_name] == sparse_var_name:
continue
input_var = program.global_block().var(input_var_name)
sparse_var.append(input_var)
sparse_tables.append(sparse_var_name)
unique_sparse_var[input_var_name] = sparse_var_name
# The output stays an empty list unless mode is SYNC or HALF_ASYNC, in which
# case a new variable named by framework.generate_control_dev_var_name() is
# created in the global block and used as the send op's "Out".
dummy_output = []
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
# A single send op carries all collected ids variables as "X"; the parallel
# list of table names is passed through the "send_varnames" attribute.
program.global_block().append_op(
type="send",
inputs={"X": sparse_var},
outputs={"Out": dummy_output},
attrs={
"send_varnames": sparse_tables,
"merge_add": True,
"use_send_handler": False,
"endpoints": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
# Returns either the created control variable or the empty list from above.
return dummy_output
def _append_barrier_op(dummys): def _append_barrier_op(dummys):
program.global_block().append_op( program.global_block().append_op(
type="send_barrier", type="send_barrier",
@ -214,8 +251,12 @@ def append_send_ops_pass(program, config):
sends = config.get_trainer_send_context() sends = config.get_trainer_send_context()
for merged_name, send in sends.items(): if mode == DistributedMode.GEO:
dummys.append(_append_send_op(send.origin_varnames(), merged_name)) dummys.append(_append_sparse_ids_send_op())
else:
for merged_name, send in sends.items():
dummys.append(
_append_grad_send_op(send.origin_varnames(), merged_name))
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]: if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
_append_barrier_op(dummys) _append_barrier_op(dummys)

@ -27,6 +27,8 @@ import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
paddle.enable_static()
class TestCommunicator(unittest.TestCase): class TestCommunicator(unittest.TestCase):
def net(self): def net(self):

Loading…
Cancel
Save