fix fetch handler error with pslib (#20679)

* fix fetch handler error with pslib
* fix distributed lookup table op with 1 pserver
revert-20712-fix_depthwise_conv
tangwei12 5 years ago committed by GitHub
parent 78431dc7bc
commit 1d925440ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -998,18 +998,6 @@ class Executor(object):
if fetch_handler is not None:
fetch_instance = fetch_handler
elif fetch_handler is None and fetch_list is not None:
class FH(FetchHandler):
def handler(self, fetch_target_vars):
for i in range(len(fetch_target_vars)):
print("{}: \n {}\n".format(fetch_info[i],
fetch_target_vars[i]))
fetch_target_names = [var.name for var in fetch_list]
fetch_instance = FH(fetch_target_names,
period_secs=print_period,
return_np=False)
else:
fetch_instance = FetchHandler([])
@ -1018,7 +1006,10 @@ class Executor(object):
dataset=dataset,
scope=scope,
thread=thread,
debug=debug)
debug=debug,
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
trainer._set_infer(is_infer)
trainer._gen_trainer_desc()

@ -793,6 +793,8 @@ class DistributeTranspiler(object):
if self.sync_mode:
fetch_barrier_input.extend(splited_var)
self._update_remote_sparse_update_op(program, need_sparse_update_params)
if self.sync_mode:
# form a WAW dependency
program.global_block().append_op(
@ -806,11 +808,10 @@ class DistributeTranspiler(object):
})
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[param_varname]
if param_varname not in self.sparse_param_to_height_sections:
if not self.config.runtime_split_send_recv:
if len(splited_var
) > 1 and not self.config.runtime_split_send_recv:
program.global_block().append_op(
type="concat",
inputs={"X": splited_var},
@ -820,8 +821,6 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
self._update_remote_sparse_update_op(program,
need_sparse_update_params)
if not self.sync_mode:
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0:

Loading…
Cancel
Save