|
|
|
@ -483,11 +483,11 @@ def save_checkpoint(executor,
|
|
|
|
|
:param main_program
|
|
|
|
|
:param max_num_checkpoints
|
|
|
|
|
"""
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
|
if checkpoint_dir.strip() is None:
|
|
|
|
|
raise ValueError("'checkpoint_dir' should not be None")
|
|
|
|
|
|
|
|
|
|
if trainer_args and not isinstance(trainer_args, dict):
|
|
|
|
|
raise TypeError("The type of 'trainer_args' should be dict")
|
|
|
|
|
if trainer_args:
|
|
|
|
|
assert isinstance(trainer_args, dict)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
os.makedirs(checkpoint_dir)
|
|
|
|
@ -514,11 +514,11 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
|
|
|
|
|
:param main_program
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
|
if checkpoint_dir.strip() is None:
|
|
|
|
|
raise ValueError("'checkpoint_dir' should not be None")
|
|
|
|
|
|
|
|
|
|
if serial is None or serial < 0:
|
|
|
|
|
raise ValueError("The values of 'serial' should not be None or <0 ")
|
|
|
|
|
raise ValueError("'serial' should not be None or <0 ")
|
|
|
|
|
|
|
|
|
|
if main_program is None:
|
|
|
|
|
raise ValueError('main_program should not be None.')
|
|
|
|
@ -536,8 +536,8 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
|
:param delete_dir
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
|
if checkpoint_dir.strip() is None:
|
|
|
|
|
raise ValueError("'checkpoint_dir' should not be None")
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
|
|
|
|
|
|
|
|
|
|
if delete_dir and not os.listdir(checkpoint_dir):
|
|
|
|
@ -590,8 +590,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_trainer_args(dirname, trainer_id, trainer_args):
|
|
|
|
|
if not isinstance(trainer_args, dict):
|
|
|
|
|
raise TypeError("The type of 'trainer_args' should be dict")
|
|
|
|
|
assert isinstance(trainer_args, dict)
|
|
|
|
|
|
|
|
|
|
cur_dir = _get_trainer_dir(dirname, trainer_id)
|
|
|
|
|
|
|
|
|
|
for name, value in trainer_args.iteritems():
|
|
|
|
@ -602,12 +602,11 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
|
|
|
|
|
assert isinstance(trainer_args, list)
|
|
|
|
|
|
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
|
cur_dir = _get_trainer_dir(cur_dir, trainer_id)
|
|
|
|
|
|
|
|
|
|
if not isinstance(trainer_args, list):
|
|
|
|
|
raise TypeError("The type of 'trainer_args' should be list")
|
|
|
|
|
|
|
|
|
|
ret_values = []
|
|
|
|
|
|
|
|
|
|
for arg in trainer_args:
|
|
|
|
|