update pserver startup

shanyi15-patch-3
tangwei12 8 years ago
parent 3dd274657f
commit 4220b31d4f

@ -520,6 +520,11 @@ class DistributeTranspiler:
return var.persistable
def get_train_startup_program(self, checkpoint_load_dir=None):
"""
Get train startup program.
If checkpoint_load_dir is None, rerurn default startup program.
IF checkpoint_load_dir is Exist, add checkpoint_load op and load Var.
"""
startup_prog = default_startup_program()
if not checkpoint_load_dir:
@ -536,7 +541,10 @@ class DistributeTranspiler:
attrs={"dir": checkpoint_load_dir})
return startup_prog
def get_startup_program(self, endpoint, pserver_program):
def get_startup_program(self,
endpoint,
pserver_program,
checkpoint_load_dir=None):
"""
Get startup program for current parameter server.
Modify operator input variables if there are variables that
@ -561,6 +569,7 @@ class DistributeTranspiler:
created_var_map[var.name] = tmpvar
# 2. rename op outputs
load_vars = []
for op in orig_s_prog.global_block().ops:
new_inputs = dict()
new_outputs = dict()
@ -588,6 +597,16 @@ class DistributeTranspiler:
inputs=new_inputs,
outputs=new_outputs,
attrs=op.attrs)
for var in new_outputs.values():
load_vars.append(var.name)
# add checkpoint op
if not checkpoint_load_dir:
return s_prog
s_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
attrs={"dir": checkpoint_load_dir})
return s_prog
# transpiler function for dis lookup_table

Loading…
Cancel
Save