|
|
|
@ -793,6 +793,8 @@ class DistributeTranspiler(object):
|
|
|
|
|
if self.sync_mode:
|
|
|
|
|
fetch_barrier_input.extend(splited_var)
|
|
|
|
|
|
|
|
|
|
self._update_remote_sparse_update_op(program, need_sparse_update_params)
|
|
|
|
|
|
|
|
|
|
if self.sync_mode:
|
|
|
|
|
# form a WAW dependency
|
|
|
|
|
program.global_block().append_op(
|
|
|
|
@ -806,11 +808,10 @@ class DistributeTranspiler(object):
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
|
|
|
|
|
if len(splited_var) <= 1:
|
|
|
|
|
continue
|
|
|
|
|
orig_param = program.global_block().vars[param_varname]
|
|
|
|
|
if param_varname not in self.sparse_param_to_height_sections:
|
|
|
|
|
if not self.config.runtime_split_send_recv:
|
|
|
|
|
if len(splited_var
|
|
|
|
|
) > 1 and not self.config.runtime_split_send_recv:
|
|
|
|
|
program.global_block().append_op(
|
|
|
|
|
type="concat",
|
|
|
|
|
inputs={"X": splited_var},
|
|
|
|
@ -820,8 +821,6 @@ class DistributeTranspiler(object):
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
self._update_remote_sparse_update_op(program,
|
|
|
|
|
need_sparse_update_params)
|
|
|
|
|
if not self.sync_mode:
|
|
|
|
|
lr_ops = self._get_lr_ops()
|
|
|
|
|
if len(lr_ops) > 0:
|
|
|
|
|