|
|
|
@ -360,7 +360,6 @@ class Trainer(object):
|
|
|
|
|
self.train_program = t.get_pserver_program(current_endpoint)
|
|
|
|
|
self.startup_program = t.get_startup_program(current_endpoint,
|
|
|
|
|
self.train_program)
|
|
|
|
|
self.slice_vars = t.get_slice_vars_and_atts(current_endpoint)
|
|
|
|
|
elif training_role == "TRAINER":
|
|
|
|
|
self.train_program = t.get_trainer_program()
|
|
|
|
|
else:
|
|
|
|
@ -609,16 +608,6 @@ class Trainer(object):
|
|
|
|
|
|
|
|
|
|
# Pserver Load
|
|
|
|
|
else:
|
|
|
|
|
# load slice_vars
|
|
|
|
|
if self.slice_vars != None and len(self.slice_vars) != 0:
|
|
|
|
|
load_checkpoint(
|
|
|
|
|
executor=exe,
|
|
|
|
|
checkpoint_dir=checkpoint_dir,
|
|
|
|
|
main_program=self.startup_program,
|
|
|
|
|
role_id=self.checkpoint_cfg.pserver_id,
|
|
|
|
|
is_trainer=False,
|
|
|
|
|
load_slice_up_vars=self.slice_vars)
|
|
|
|
|
|
|
|
|
|
# load lookup table
|
|
|
|
|
if self.checkpoint_cfg.lookup_table_name:
|
|
|
|
|
load_checkpoint(
|
|
|
|
@ -766,7 +755,6 @@ def load_checkpoint(executor,
|
|
|
|
|
is_trainer=True,
|
|
|
|
|
load_models=False,
|
|
|
|
|
load_trainer_args=None,
|
|
|
|
|
load_slice_up_vars=None,
|
|
|
|
|
load_lookup_table=None):
|
|
|
|
|
"""
|
|
|
|
|
This function filters out all checkpoint variables from the give
|
|
|
|
@ -827,18 +815,11 @@ def load_checkpoint(executor,
|
|
|
|
|
_load_persistable_vars(executor, checkpoint_dir, main_program, True)
|
|
|
|
|
return
|
|
|
|
|
if load_trainer_args:
|
|
|
|
|
|
|
|
|
|
print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}".
|
|
|
|
|
format(checkpoint_dir, role_id, load_trainer_args))
|
|
|
|
|
|
|
|
|
|
trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
|
|
|
|
|
load_trainer_args)
|
|
|
|
|
return trainer_args_ret
|
|
|
|
|
# pserver load
|
|
|
|
|
else:
|
|
|
|
|
if load_slice_up_vars:
|
|
|
|
|
_load_slice_up_vars(executor, checkpoint_dir, load_slice_up_vars)
|
|
|
|
|
return
|
|
|
|
|
if load_lookup_table:
|
|
|
|
|
_load_lookup_table_vars(executor, checkpoint_dir, main_program,
|
|
|
|
|
role_id, load_lookup_table)
|
|
|
|
@ -911,51 +892,6 @@ def _load_persistable_vars(executor, dirname, program, has_model_dir=False):
|
|
|
|
|
filename=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_slice_up_vars(executor, dirname, slice_vars):
|
|
|
|
|
if slice_vars == None or len(slice_vars) == 0:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
dirname = _get_model_dir(dirname)
|
|
|
|
|
|
|
|
|
|
load_prog = framework.Program()
|
|
|
|
|
load_block = load_prog.global_block()
|
|
|
|
|
|
|
|
|
|
for var_tuple in slice_vars:
|
|
|
|
|
orig_var = var_tuple[0]
|
|
|
|
|
start = var_tuple[1]
|
|
|
|
|
slice_var = var_tuple[2]
|
|
|
|
|
end = start + reduce(lambda x, y: x * y, slice_var.shape)
|
|
|
|
|
|
|
|
|
|
clone_orig_var = load_block.create_var(
|
|
|
|
|
name=orig_var.name,
|
|
|
|
|
type=orig_var.type,
|
|
|
|
|
shape=orig_var.shape,
|
|
|
|
|
dtype=orig_var.dtype,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
clone_slice_var = load_block.create_var(
|
|
|
|
|
name=slice_var.name,
|
|
|
|
|
type=slice_var.type,
|
|
|
|
|
shape=slice_var.shape,
|
|
|
|
|
dtype=slice_var.dtype,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
load_block.append_op(
|
|
|
|
|
type='load',
|
|
|
|
|
inputs={},
|
|
|
|
|
outputs={'Out': [clone_orig_var]},
|
|
|
|
|
attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
|
|
|
|
|
load_block.append_op(
|
|
|
|
|
type="slice",
|
|
|
|
|
inputs={'Input': clone_orig_var},
|
|
|
|
|
outputs={'Out': clone_slice_var},
|
|
|
|
|
attrs={'axes': [0],
|
|
|
|
|
'starts': [start],
|
|
|
|
|
'ends': [end]})
|
|
|
|
|
|
|
|
|
|
executor.run(load_prog)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
|
|
|
|
|
"""
|
|
|
|
|
The parameter server will load lookup table's local file in
|
|
|
|
|