|
|
|
@ -520,6 +520,11 @@ class DistributeTranspiler:
|
|
|
|
|
return var.persistable
|
|
|
|
|
|
|
|
|
|
def get_train_startup_program(self, checkpoint_load_dir=None):
|
|
|
|
|
"""
|
|
|
|
|
Get train startup program.
|
|
|
|
|
If checkpoint_load_dir is None, rerurn default startup program.
|
|
|
|
|
IF checkpoint_load_dir is Exist, add checkpoint_load op and load Var.
|
|
|
|
|
"""
|
|
|
|
|
startup_prog = default_startup_program()
|
|
|
|
|
|
|
|
|
|
if not checkpoint_load_dir:
|
|
|
|
@ -536,7 +541,10 @@ class DistributeTranspiler:
|
|
|
|
|
attrs={"dir": checkpoint_load_dir})
|
|
|
|
|
return startup_prog
|
|
|
|
|
|
|
|
|
|
def get_startup_program(self, endpoint, pserver_program):
|
|
|
|
|
def get_startup_program(self,
|
|
|
|
|
endpoint,
|
|
|
|
|
pserver_program,
|
|
|
|
|
checkpoint_load_dir=None):
|
|
|
|
|
"""
|
|
|
|
|
Get startup program for current parameter server.
|
|
|
|
|
Modify operator input variables if there are variables that
|
|
|
|
@ -561,6 +569,7 @@ class DistributeTranspiler:
|
|
|
|
|
created_var_map[var.name] = tmpvar
|
|
|
|
|
|
|
|
|
|
# 2. rename op outputs
|
|
|
|
|
load_vars = []
|
|
|
|
|
for op in orig_s_prog.global_block().ops:
|
|
|
|
|
new_inputs = dict()
|
|
|
|
|
new_outputs = dict()
|
|
|
|
@ -588,6 +597,16 @@ class DistributeTranspiler:
|
|
|
|
|
inputs=new_inputs,
|
|
|
|
|
outputs=new_outputs,
|
|
|
|
|
attrs=op.attrs)
|
|
|
|
|
for var in new_outputs.values():
|
|
|
|
|
load_vars.append(var.name)
|
|
|
|
|
# add checkpoint op
|
|
|
|
|
if not checkpoint_load_dir:
|
|
|
|
|
return s_prog
|
|
|
|
|
|
|
|
|
|
s_prog.global_block().append_op(
|
|
|
|
|
type="checkpoint_load",
|
|
|
|
|
inputs={"X": load_vars},
|
|
|
|
|
attrs={"dir": checkpoint_load_dir})
|
|
|
|
|
return s_prog
|
|
|
|
|
|
|
|
|
|
# transpiler function for dis lookup_table
|
|
|
|
|