|
|
@ -247,10 +247,11 @@ class DistributeTranspiler(object):
|
|
|
|
return sparse_update_ops
|
|
|
|
return sparse_update_ops
|
|
|
|
|
|
|
|
|
|
|
|
def _update_remote_sparse_update_op(self, param_varname, height_sections,
|
|
|
|
def _update_remote_sparse_update_op(self, param_varname, height_sections,
|
|
|
|
endpint_map):
|
|
|
|
endpint_map, table_names):
|
|
|
|
for op in self.sparse_update_ops:
|
|
|
|
for op in self.sparse_update_ops:
|
|
|
|
if param_varname in op.input_arg_names:
|
|
|
|
if param_varname in op.input_arg_names:
|
|
|
|
op._set_attr('epmap', endpint_map)
|
|
|
|
op._set_attr('epmap', endpint_map)
|
|
|
|
|
|
|
|
op._set_attr('table_names', table_names)
|
|
|
|
op._set_attr('height_sections', height_sections)
|
|
|
|
op._set_attr('height_sections', height_sections)
|
|
|
|
op._set_attr('trainer_id', self.trainer_id)
|
|
|
|
op._set_attr('trainer_id', self.trainer_id)
|
|
|
|
|
|
|
|
|
|
|
@ -326,6 +327,7 @@ class DistributeTranspiler(object):
|
|
|
|
# get all sparse update ops
|
|
|
|
# get all sparse update ops
|
|
|
|
self.sparse_update_ops = self._get_all_remote_sparse_update_op(
|
|
|
|
self.sparse_update_ops = self._get_all_remote_sparse_update_op(
|
|
|
|
self.origin_program)
|
|
|
|
self.origin_program)
|
|
|
|
|
|
|
|
# use_sparse_update_param_name -> split_height_section
|
|
|
|
self.sparse_param_to_height_sections = dict()
|
|
|
|
self.sparse_param_to_height_sections = dict()
|
|
|
|
|
|
|
|
|
|
|
|
# add distributed attrs to program
|
|
|
|
# add distributed attrs to program
|
|
|
@ -365,6 +367,13 @@ class DistributeTranspiler(object):
|
|
|
|
splited_grad_varname = splited_vars[0].name
|
|
|
|
splited_grad_varname = splited_vars[0].name
|
|
|
|
index = find_op_by_output_arg(
|
|
|
|
index = find_op_by_output_arg(
|
|
|
|
program.global_block(), splited_grad_varname, reverse=True)
|
|
|
|
program.global_block(), splited_grad_varname, reverse=True)
|
|
|
|
|
|
|
|
if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
|
|
|
|
|
|
|
|
sparse_param_name = self.grad_name_to_param_name[
|
|
|
|
|
|
|
|
splited_grad_varname]
|
|
|
|
|
|
|
|
if self._is_input_of_remote_sparse_update_op(
|
|
|
|
|
|
|
|
sparse_param_name):
|
|
|
|
|
|
|
|
self.sparse_param_to_height_sections[
|
|
|
|
|
|
|
|
sparse_param_name] = [splited_vars[0].shape[0]]
|
|
|
|
elif len(splited_vars) > 1:
|
|
|
|
elif len(splited_vars) > 1:
|
|
|
|
orig_var = program.global_block().vars[splited_grad_varname]
|
|
|
|
orig_var = program.global_block().vars[splited_grad_varname]
|
|
|
|
index = find_op_by_output_arg(
|
|
|
|
index = find_op_by_output_arg(
|
|
|
@ -435,9 +444,11 @@ class DistributeTranspiler(object):
|
|
|
|
all_recv_outputs = []
|
|
|
|
all_recv_outputs = []
|
|
|
|
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
|
|
|
|
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
|
|
|
|
eps = []
|
|
|
|
eps = []
|
|
|
|
|
|
|
|
table_names = []
|
|
|
|
for var in splited_var:
|
|
|
|
for var in splited_var:
|
|
|
|
index = [v.name for v in recv_vars].index(var.name)
|
|
|
|
index = [v.name for v in recv_vars].index(var.name)
|
|
|
|
eps.append(eplist[index])
|
|
|
|
eps.append(eplist[index])
|
|
|
|
|
|
|
|
table_names.append(var.name)
|
|
|
|
if self.sync_mode:
|
|
|
|
if self.sync_mode:
|
|
|
|
recv_dep_in = send_barrier_out
|
|
|
|
recv_dep_in = send_barrier_out
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -457,8 +468,8 @@ class DistributeTranspiler(object):
|
|
|
|
if param_varname in self.sparse_param_to_height_sections:
|
|
|
|
if param_varname in self.sparse_param_to_height_sections:
|
|
|
|
height_sections = self.sparse_param_to_height_sections[
|
|
|
|
height_sections = self.sparse_param_to_height_sections[
|
|
|
|
param_varname]
|
|
|
|
param_varname]
|
|
|
|
self._update_remote_sparse_update_op(param_varname,
|
|
|
|
self._update_remote_sparse_update_op(
|
|
|
|
height_sections, eps)
|
|
|
|
param_varname, height_sections, eps, table_names)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
all_recv_outputs.extend(splited_var)
|
|
|
|
all_recv_outputs.extend(splited_var)
|
|
|
|
program.global_block().append_op(
|
|
|
|
program.global_block().append_op(
|
|
|
|