|
|
|
|
@ -525,12 +525,15 @@ class DistributeTranspiler:
|
|
|
|
|
if not checkpoint_load_dir:
|
|
|
|
|
return startup_prog
|
|
|
|
|
|
|
|
|
|
load_vars = []
|
|
|
|
|
for var in startup_prog.list_vars():
|
|
|
|
|
if self.is_persistable(var):
|
|
|
|
|
print("var: %s" % var.name)
|
|
|
|
|
load_vars.append(var.name)
|
|
|
|
|
|
|
|
|
|
startup_prog.global_block().append_op(
|
|
|
|
|
type="checkpoint_load", attrs={"dir": checkpoint_load_dir})
|
|
|
|
|
type="checkpoint_load",
|
|
|
|
|
outputs={"Out": load_vars},
|
|
|
|
|
attrs={"dir": checkpoint_load_dir})
|
|
|
|
|
return startup_prog
|
|
|
|
|
|
|
|
|
|
def get_startup_program(self, endpoint, pserver_program):
|
|
|
|
|
|