|
|
|
@ -315,10 +315,21 @@ class DistributeTranspiler:
|
|
|
|
|
"sync_mode": self.sync_mode
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
serial_var = program.global_block().create_var(
|
|
|
|
|
name="SERIAL_NUMBER",
|
|
|
|
|
persistable=True,
|
|
|
|
|
type=core.VarDesc.VarType.RAW)
|
|
|
|
|
|
|
|
|
|
save_vars = []
|
|
|
|
|
for var in self.origin_program.list_vars():
|
|
|
|
|
if self.is_persistable(var):
|
|
|
|
|
save_vars.append(var.name)
|
|
|
|
|
|
|
|
|
|
program.global_block().append_op(
|
|
|
|
|
type="checkpoint_save",
|
|
|
|
|
inputs={"X": send_outputs},
|
|
|
|
|
attrs={"overwrite": True,
|
|
|
|
|
inputs={"X": save_vars},
|
|
|
|
|
outputs={"Serial": serial_var},
|
|
|
|
|
attrs={"overwrite": False,
|
|
|
|
|
"dir": "/workspace/ckpt/"})
|
|
|
|
|
|
|
|
|
|
# step4: Concatenate the parameter splits together after recv.
|
|
|
|
@ -501,6 +512,27 @@ class DistributeTranspiler:
|
|
|
|
|
pserver_program.sync_with_cpp()
|
|
|
|
|
return pserver_program
|
|
|
|
|
|
|
|
|
|
def is_persistable(self, var):
    """
    Tell whether `var` counts as a persistable variable.

    Variables whose desc type is FEED_MINIBATCH, FETCH_LIST or RAW are
    never reported as persistable; for any other variable the result is
    the variable's own `persistable` attribute.
    """
    excluded_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                      core.VarDesc.VarType.FETCH_LIST,
                      core.VarDesc.VarType.RAW)
    var_type = var.desc.type()
    if any(var_type == excluded for excluded in excluded_types):
        return False
    return var.persistable
|
|
|
|
|
|
|
|
|
|
def get_train_startup_program(self, checkpoint_load_dir=None):
    """
    Get the trainer startup program, optionally loading a checkpoint.

    Args:
        checkpoint_load_dir: value passed as the "dir" attribute of the
            appended "checkpoint_load" op. When falsy (None or an empty
            string), the default startup program is returned unchanged.

    Returns:
        The program returned by default_startup_program(); when
        checkpoint_load_dir is given, a "checkpoint_load" op has been
        appended to its global block.
    """
    import logging

    startup_prog = default_startup_program()

    if not checkpoint_load_dir:
        return startup_prog

    # Report persistable variables through logging at DEBUG level rather
    # than print(), so library callers do not get unconditional output on
    # stdout.
    logger = logging.getLogger(__name__)
    for var in startup_prog.list_vars():
        if self.is_persistable(var):
            logger.debug("var: %s", var.name)

    # NOTE(review): the op is appended to the object returned by
    # default_startup_program(); if that object is shared between calls,
    # calling this method repeatedly would presumably append the op each
    # time -- confirm against callers.
    startup_prog.global_block().append_op(
        type="checkpoint_load", attrs={"dir": checkpoint_load_dir})
    return startup_prog
|
|
|
|
|
|
|
|
|
|
def get_startup_program(self, endpoint, pserver_program):
|
|
|
|
|
"""
|
|
|
|
|
Get startup program for current parameter server.
|
|
|
|
|