|
|
|
@ -454,3 +454,90 @@ def get_parameter_value_by_name(name, executor, program=None):
|
|
|
|
|
program = default_main_program()
|
|
|
|
|
var = program.global_block().var(name)
|
|
|
|
|
return get_parameter_value(var, executor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUCCESS = "_SUCCESS"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(executor,
|
|
|
|
|
dirname,
|
|
|
|
|
keep_max=10,
|
|
|
|
|
save_secs=600,
|
|
|
|
|
main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Save Variables to Checkpint Dir
|
|
|
|
|
|
|
|
|
|
:param dirname
|
|
|
|
|
:param keep_max
|
|
|
|
|
:param save_secs
|
|
|
|
|
"""
|
|
|
|
|
if dirname is None:
|
|
|
|
|
raise Exception("save checkpoint dir can not be none")
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(dirname):
|
|
|
|
|
os.makedirs(dirname)
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(dirname) + 1
|
|
|
|
|
|
|
|
|
|
cur_dir = os.path.join(dirname, serial)
|
|
|
|
|
save_persistables(executor, cur_dir, main_program)
|
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def restore_checkpoint(dirname, executor, main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Load Variables from Checkpint Dir
|
|
|
|
|
|
|
|
|
|
:param dir
|
|
|
|
|
"""
|
|
|
|
|
if dirname is None and os.path.isdir(dirname):
|
|
|
|
|
raise Exception("restore checkpoint can not load variables from %s" %
|
|
|
|
|
dirname)
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(dirname) + 1
|
|
|
|
|
|
|
|
|
|
if serial < -1:
|
|
|
|
|
return
|
|
|
|
|
cur_dir = os.path.join(dirname, serial)
|
|
|
|
|
load_persistables(executor, cur_dir, main_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _write_success(dirname):
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
success_file = os.path.join(dirname, SUCCESS)
|
|
|
|
|
with open(success_file, 'a'):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_lastest_checkpoint_dir(checkpoint_dir):
|
|
|
|
|
"""
|
|
|
|
|
get the biggest number in checkpoint_dir, which has _SUCCESS
|
|
|
|
|
"""
|
|
|
|
|
if not checkpoint_dir.strip():
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
def has_success(checkpoint_dir, cur_dir):
|
|
|
|
|
"""
|
|
|
|
|
is _SUCCESS in this dir
|
|
|
|
|
"""
|
|
|
|
|
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
int(cur_dir)
|
|
|
|
|
except ValueError:
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
|
|
|
|
|
if os.path.isfile(success_path):
|
|
|
|
|
return int(cur_dir)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
current_dir = -1
|
|
|
|
|
dirs = os.listdir(checkpoint_dir)
|
|
|
|
|
for cur_dir in dirs:
|
|
|
|
|
success_num = has_success(checkpoint_dir, cur_dir)
|
|
|
|
|
if success_num > current_dir:
|
|
|
|
|
current_dir = success_num
|
|
|
|
|
return current_dir
|
|
|
|
|