|
|
|
@ -218,7 +218,8 @@ class DistributeTranspiler(object):
|
|
|
|
|
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
|
|
|
|
|
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
|
|
|
|
|
# shuffle the map will avoid the uneven distribution above
|
|
|
|
|
grad_var_mapping_items = list(self.grad_var_mapping.items())
|
|
|
|
|
grad_var_mapping_items = list(
|
|
|
|
|
six.moves.iteritems(self.grad_var_mapping))
|
|
|
|
|
|
|
|
|
|
if not self.config.slice_var_up:
|
|
|
|
|
random.seed(self.origin_program.random_seed)
|
|
|
|
@ -279,7 +280,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
|
|
|
|
|
|
|
|
|
|
# step4: Concat the parameters splits together after recv.
|
|
|
|
|
for varname, splited_var in list(self.param_var_mapping.items()):
|
|
|
|
|
for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
|
|
|
|
|
eps = []
|
|
|
|
|
for var in splited_var:
|
|
|
|
|
index = [v.name for v in recv_vars].index(var.name)
|
|
|
|
@ -303,7 +304,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
for varname, splited_var in list(self.param_var_mapping.items()):
|
|
|
|
|
for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
|
|
|
|
|
if len(splited_var) <= 1:
|
|
|
|
|
continue
|
|
|
|
|
orig_param = program.global_block().vars[varname]
|
|
|
|
@ -560,7 +561,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
# 1. create vars in pserver program to startup program
|
|
|
|
|
pserver_vars = pserver_program.global_block().vars
|
|
|
|
|
created_var_map = collections.OrderedDict()
|
|
|
|
|
for _, var in list(pserver_vars.items()):
|
|
|
|
|
for _, var in six.moves.iteritems(pserver_vars):
|
|
|
|
|
tmpvar = s_prog.global_block()._clone_variable(var)
|
|
|
|
|
created_var_map[var.name] = tmpvar
|
|
|
|
|
|
|
|
|
@ -997,7 +998,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
block_map[varname] = []
|
|
|
|
|
block_map[varname].append((int(offset), int(size)))
|
|
|
|
|
|
|
|
|
|
for varname, splited in list(block_map.items()):
|
|
|
|
|
for varname, splited in six.moves.iteritems(block_map):
|
|
|
|
|
orig_var = program.global_block().var(varname)
|
|
|
|
|
if len(splited) == 1:
|
|
|
|
|
if self.sync_mode and add_trainer_suffix:
|
|
|
|
@ -1248,9 +1249,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
|
|
|
|
|
def _is_splited_grad_var(self, var, var_dict):
|
|
|
|
|
grad_block = None
|
|
|
|
|
# TODO(minqiyang): replace these items() with six.iteritems() to
|
|
|
|
|
# improve memory
|
|
|
|
|
for _, g in list(var_dict.items()):
|
|
|
|
|
for _, g in six.moves.iteritems(var_dict):
|
|
|
|
|
if self._orig_varname(g.name) == self._orig_varname(var.name):
|
|
|
|
|
if g.name.find(".trainer_") == -1:
|
|
|
|
|
grad_block = g
|
|
|
|
@ -1260,7 +1259,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
def _clone_lr_op(self, program, block, op):
|
|
|
|
|
inputs = self._get_input_map_from_op(
|
|
|
|
|
self.origin_program.global_block().vars, op)
|
|
|
|
|
for key, varlist in list(inputs.items()):
|
|
|
|
|
for key, varlist in six.moves.iteritems(inputs):
|
|
|
|
|
if not isinstance(varlist, list):
|
|
|
|
|
varlist = [varlist]
|
|
|
|
|
for var in varlist:
|
|
|
|
@ -1269,7 +1268,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
|
|
|
|
|
outputs = self._get_output_map_from_op(
|
|
|
|
|
self.origin_program.global_block().vars, op)
|
|
|
|
|
for key, varlist in list(outputs.items()):
|
|
|
|
|
for key, varlist in six.moves.iteritems(outputs):
|
|
|
|
|
if not isinstance(varlist, list):
|
|
|
|
|
varlist = [varlist]
|
|
|
|
|
for var in varlist:
|
|
|
|
@ -1284,7 +1283,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
# Append the ops for parameters that do not need to be optimized/updated
|
|
|
|
|
inputs = self._get_input_map_from_op(
|
|
|
|
|
self.origin_program.global_block().vars, opt_op)
|
|
|
|
|
for key, varlist in list(inputs.items()):
|
|
|
|
|
for key, varlist in six.moves.iteritems(inputs):
|
|
|
|
|
if not isinstance(varlist, list):
|
|
|
|
|
varlist = [varlist]
|
|
|
|
|
for var in varlist:
|
|
|
|
@ -1303,7 +1302,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
|
|
|
|
|
outputs = self._get_output_map_from_op(
|
|
|
|
|
self.origin_program.global_block().vars, opt_op)
|
|
|
|
|
for key, varlist in list(outputs.items()):
|
|
|
|
|
for key, varlist in six.moves.iteritems(outputs):
|
|
|
|
|
if not isinstance(varlist, list):
|
|
|
|
|
varlist = [varlist]
|
|
|
|
|
for var in varlist:
|
|
|
|
|