modify get trainer param

shanyi15-patch-3
tangwei12 8 years ago
parent 886897ccf7
commit 9cf47afe61

@ -525,12 +525,15 @@ class DistributeTranspiler:
if not checkpoint_load_dir:
return startup_prog
load_vars = []
for var in startup_prog.list_vars():
if self.is_persistable(var):
print("var: %s" % var.name)
load_vars.append(var.name)
startup_prog.global_block().append_op(
type="checkpoint_load", attrs={"dir": checkpoint_load_dir})
type="checkpoint_load",
outputs={"Out": load_vars},
attrs={"dir": checkpoint_load_dir})
return startup_prog
def get_startup_program(self, endpoint, pserver_program):

Loading…
Cancel
Save