@@ -34,6 +34,7 @@ import math
 import random
 import numpy as np
 import collections
 import six
 
 from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
 from .. import core, framework
@@ -210,6 +211,9 @@ class DistributeTranspiler(object):
 
         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()
+        self.param_name_to_grad_name = dict()
+        for param_var, grad_var in self.params_grads:
+            self.param_name_to_grad_name[param_var.name] = grad_var.name
 
         # add distributed attrs to program
         self.origin_program._is_distributed = True
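These three added lines build a parameter-name to gradient-name lookup that the recv wiring below depends on. A minimal sketch of what it produces, with made-up names (the real `self.params_grads` holds `(param_var, grad_var)` Variable pairs, not strings):

```python
# Hypothetical (param, grad) name pairs; the real loop reads
# self.params_grads and stores Variable names.
params_grads = [("fc_w", "fc_w@GRAD"), ("fc_b", "fc_b@GRAD")]

param_name_to_grad_name = dict()
for param_name, grad_name in params_grads:
    param_name_to_grad_name[param_name] = grad_name

# Later, the recv op for a parameter can look up the send-side gradient
# (and its dummy dependency variable) by name:
assert param_name_to_grad_name["fc_w"] == "fc_w@GRAD"
```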
@@ -236,34 +240,39 @@ class DistributeTranspiler(object):
             random.seed(self.origin_program.random_seed)
             random.shuffle(grad_var_mapping_items)
 
-        for orig_varname, splited_vars in grad_var_mapping_items:
+        grad_name_to_send_dummy_out = dict()
+        for grad_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)
 
             if not self.config.slice_var_up:
                 assert (len(splited_vars) == 1)
 
+            splited_grad_varname = grad_varname
             if len(splited_vars) == 1:
-                orig_varname = splited_vars[0].name
+                splited_grad_varname = splited_vars[0].name
                 index = find_op_by_output_arg(program.global_block(),
-                                              orig_varname)
+                                              splited_grad_varname)
             elif len(splited_vars) > 1:
-                orig_var = program.global_block().vars[orig_varname]
+                orig_var = program.global_block().vars[splited_grad_varname]
                 index = find_op_by_output_arg(program.global_block(),
-                                              orig_varname)
+                                              splited_grad_varname)
                 self._insert_split_op(program, orig_var, index, splited_vars)
                 index += 1
             else:
                 AssertionError("Can not insert the send op by original "
-                               "variable name :", orig_varname)
+                               "variable name :", splited_grad_varname)
 
+            dummy_output = program.global_block().create_var()
+            grad_name_to_send_dummy_out[grad_varname] = dummy_output
             program.global_block()._insert_op(
                 index=index + 1,
                 type="send",
                 inputs={"X": splited_vars},
-                outputs={},
+                outputs={"Out": dummy_output},
                 attrs={
                     "epmap": eplist,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    "sync_mode": not self.sync_mode,
                 })
             for _, var in enumerate(splited_vars):
                 send_vars.append(var)
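The heart of this hunk: every send op now produces a dummy output variable, recorded in `grad_name_to_send_dummy_out`, so that the matching recv op (later hunks) can consume it; the executor's dependency analysis then has to schedule send before recv. A toy, self-contained sketch of that ordering idea, using plain dicts rather than the Paddle operator API:

```python
# Toy dependency-driven scheduler; the "send"/"recv" entries mimic the
# ops built above, with invented variable names.
ops = [
    {"type": "recv", "inputs": ["send_dummy_0"], "outputs": ["fc_w"]},
    {"type": "send", "inputs": ["fc_w@GRAD"], "outputs": ["send_dummy_0"]},
]

ready = {"fc_w@GRAD"}  # variables that exist before either op runs
order = []
while len(order) < len(ops):
    for op in ops:
        if op not in order and all(v in ready for v in op["inputs"]):
            order.append(op)
            ready.update(op["outputs"])

# recv consumes the dummy variable that send produces, so send is
# scheduled first even though recv appears earlier in the list:
assert [op["type"] for op in order] == ["send", "recv"]
```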
@@ -275,7 +284,6 @@ class DistributeTranspiler(object):
                 outputs={},
                 attrs={
                     "endpoints": pserver_endpoints,
-                    "sync_mode": self.sync_mode,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
@@ -291,19 +299,21 @@ class DistributeTranspiler(object):
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
 
         # step4: Concat the parameters splits together after recv.
-        for varname, splited_var in six.iteritems(self.param_var_mapping):
+        for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             eps = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)
                 eps.append(eplist[index])
 
+            grad_send_dummy_out = grad_name_to_send_dummy_out[
+                self.param_name_to_grad_name[param_varname]]
             program.global_block().append_op(
                 type="recv",
-                inputs={},
+                inputs={"X": [grad_send_dummy_out]},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    "sync_mode": not self.sync_mode
                 })
 
         if self.sync_mode:
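This is the consumer side of the dummy variables: each recv chains the two mappings to find the dependency variable produced by its gradient's send. A hedged sketch of that double lookup, with hypothetical names standing in for Paddle Variables:

```python
param_name_to_grad_name = {"fc_w": "fc_w@GRAD"}
grad_name_to_send_dummy_out = {"fc_w@GRAD": "send_dummy_for_fc_w"}

param_varname = "fc_w"
grad_send_dummy_out = grad_name_to_send_dummy_out[
    param_name_to_grad_name[param_varname]]

# Wiring this var into recv's "X" slot makes the recv of fc_w wait for
# the send of fc_w@GRAD instead of racing it in async mode.
assert grad_send_dummy_out == "send_dummy_for_fc_w"
```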
@@ -316,10 +326,10 @@ class DistributeTranspiler(object):
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
-        for varname, splited_var in six.iteritems(self.param_var_mapping):
+        for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             if len(splited_var) <= 1:
                 continue
-            orig_param = program.global_block().vars[varname]
+            orig_param = program.global_block().vars[param_varname]
             program.global_block().append_op(
                 type="concat",
                 inputs={"X": splited_var},
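The concat step itself is untouched here (only the loop variable is renamed); for context, it stitches the received parameter slices back into one tensor. A numpy analogue, assuming the illustrative case of a parameter split into two row blocks:

```python
import numpy as np

# A parameter split into two row blocks for two pservers (shapes are
# invented); splited_var stands for what the recv ops fetched.
orig_param = np.arange(12.0).reshape(6, 2)
splited_var = [orig_param[:3], orig_param[3:]]

# Effect of the "concat" op: rebuild the full parameter from slices.
restored = np.concatenate(splited_var, axis=0)
assert np.array_equal(restored, orig_param)
```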
@@ -387,7 +397,7 @@ class DistributeTranspiler(object):
 
             op = startup_program.global_block().append_op(
                 type="recv",
-                inputs={},
+                inputs={"X": []},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
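The startup-program recv has no trainer-side send to wait on, so it passes an empty list instead of omitting the input, keeping the op signature consistent with the new send/recv definition. A tiny sketch of why an empty dependency slot is harmless under the toy scheduler above:

```python
# An op whose "X" slot is empty has no data dependencies, so a
# dependency-driven scheduler can run it immediately; the output name
# is invented.
op = {"type": "recv", "inputs": {"X": []}, "outputs": {"Out": ["fc_w.block0"]}}

ready = set()  # nothing produced yet
runnable = all(v in ready for v in op["inputs"]["X"])  # vacuously True
assert runnable
```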
@@ -826,19 +836,21 @@ class DistributeTranspiler(object):
                                        self.config.min_block_size)
         assert (len(grad_blocks) == len(param_blocks))
 
-        # origin_varname -> [splited_var]
+        # origin_param_name -> [splited_param_vars]
         self.param_var_mapping = self._create_vars_from_blocklist(
             self.origin_program, param_blocks)
+        # origin_grad_name -> [splited_grad_vars]
         self.grad_var_mapping = self._create_vars_from_blocklist(
             self.origin_program,
             grad_blocks,
             add_trainer_suffix=self.trainer_num > 1)
+        # dict(grad_splited_var -> param_splited_var)
         self.grad_param_mapping = collections.OrderedDict()
         for g, p in zip(grad_blocks, param_blocks):
             g_name, g_bid, _ = g.split(":")
             p_name, p_bid, _ = p.split(":")
             self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
                 self.param_var_mapping[p_name][int(p_bid)]
 
         # create mapping of endpoint -> split var to create pserver side program
         self.param_grad_ep_mapping = collections.OrderedDict()
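For reference, a runnable sketch of how these mappings fit together. The "var_name:block_id:block_size" block format is inferred from the `split(":")` calls, and all names and sizes are invented; the real mappings hold Variables created by `_create_vars_from_blocklist`:

```python
import collections

grad_blocks = ["fc_w@GRAD:0:6", "fc_w@GRAD:1:6"]
param_blocks = ["fc_w:0:6", "fc_w:1:6"]

# Stand-ins for the *_var_mapping dicts (origin name -> splited vars).
grad_var_mapping = {"fc_w@GRAD": ["fc_w@GRAD.block0", "fc_w@GRAD.block1"]}
param_var_mapping = {"fc_w": ["fc_w.block0", "fc_w.block1"]}

grad_param_mapping = collections.OrderedDict()
for g, p in zip(grad_blocks, param_blocks):
    g_name, g_bid, _ = g.split(":")
    p_name, p_bid, _ = p.split(":")
    grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
        param_var_mapping[p_name][int(p_bid)]

assert grad_param_mapping["fc_w@GRAD.block1"] == "fc_w.block1"
```

Using an OrderedDict rather than a plain dict keeps iteration order deterministic, so every trainer and pserver derives the same slice-to-endpoint layout.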
@@ -959,7 +971,7 @@ class DistributeTranspiler(object):
             index=op_index + 2,
             type="send",
             inputs={'X': self.trainer_side_table_grad_list},
-            outputs={},
+            outputs={'Out': []},
             attrs={
                 "sync_mode": True,
                 "epmap": pserver_endpoints,
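The distributed lookup-table path gets the same signature update: this send declares the new "Out" slot explicitly but leaves it empty, since no recv is wired to wait on the sparse table gradients; here ordering comes from `"sync_mode": True` instead. An invented example of the resulting op description:

```python
# Sketch only: names and endpoint are made up, and real ops are Paddle
# OpDesc objects, not dicts.
table_grad_send = {
    "type": "send",
    "inputs": {"X": ["table@GRAD.trainer_0"]},
    "outputs": {"Out": []},  # nothing downstream depends on this send
    "attrs": {"sync_mode": True, "epmap": ["127.0.0.1:6174"]},
}
assert table_grad_send["outputs"]["Out"] == []
```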