|
|
|
@ -456,40 +456,18 @@ def get_parameter_value_by_name(name, executor, program=None):
|
|
|
|
|
return get_parameter_value(var, executor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_persist_vars_without_grad(executor, dirname, program):
    """
    Load checkpoint variables from a directory through an executor.

    The work is delegated to ``load_vars`` with ``_is_checkpoint_var`` as the
    filter; variables whose names end with "@GRAD" are not loaded.

    :param executor: executor handed to ``load_vars``
    :param dirname: directory the variables are read from
    :param program: program passed to ``load_vars`` as ``main_program``
    """
    load_options = dict(
        dirname=dirname,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)
    load_vars(executor, **load_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_persist_vars_without_grad(executor, dirname, program):
    """
    Save checkpoint variables to a directory through an executor.

    The work is delegated to ``save_vars`` with ``_is_checkpoint_var`` as the
    filter; variables whose names end with "@GRAD" are not saved.

    :param executor: executor handed to ``save_vars``
    :param dirname: directory the variables are written to
    :param program: program passed to ``save_vars`` as ``main_program``
    """
    save_options = dict(
        dirname=dirname,
        main_program=program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)
    save_vars(executor, **save_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Name of the marker file whose presence is checked when looking for a
# completed checkpoint.
SUCCESS_MARK_FILENAME = "_SUCCESS"

# Serial checkpoint folders are named
# CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + <serial>.
CHECKPOINT_PREFIX = "checkpoint"

# Sub-folder (joined onto a directory by _get_model_dir) holding the variables.
MODEL_DIR = "__model__"

# Per-trainer folders are named
# TRAINER_PREFIX + CHECKPOINT_SEPARATOR + <trainer_id>.
TRAINER_PREFIX = "trainer"

# Separator placed between a folder prefix and its numeric suffix.
CHECKPOINT_SEPARATOR = "_"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(executor,
|
|
|
|
|
checkpoint_dir,
|
|
|
|
|
trainer_id,
|
|
|
|
|
is_chief=False,
|
|
|
|
|
trainer_args=None,
|
|
|
|
|
main_program=None,
|
|
|
|
|
max_num_checkpoints=3):
|
|
|
|
|
"""
|
|
|
|
@ -502,22 +480,35 @@ def save_checkpoint(executor,
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
:param main_program
|
|
|
|
|
:param max_num_checkpoints
|
|
|
|
|
:param is_chief
|
|
|
|
|
"""
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
|
|
|
|
|
|
if trainer_args and not isinstance(trainer_args, dict):
|
|
|
|
|
raise TypeError("The type of 'trainer_args' should be dict")
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
os.makedirs(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
|
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
|
|
|
|
|
|
save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
if is_chief:
|
|
|
|
|
save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
|
|
|
|
|
save_trainer_args(cur_dir, trainer_id, trainer_args)
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(executor, checkpoint_dir, main_program=None):
|
|
|
|
|
def need_load_checkpoint(checkpoint_dir):
    """
    Report which checkpoint serial, if any, is available for loading.

    :param checkpoint_dir: directory passed to ``_get_lastest_checkpoint_dir``
    :return: the serial returned by ``_get_lastest_checkpoint_dir``, or None
        when that serial is negative
    """
    latest_serial = _get_lastest_checkpoint_dir(checkpoint_dir)
    return None if latest_serial < 0 else latest_serial
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
|
|
|
|
|
"""
|
|
|
|
|
Load checkpoint from a directory by executor,
|
|
|
|
|
it will find the most recent saved checkpoint file and load it auto.
|
|
|
|
@ -528,14 +519,17 @@ def load_checkpoint(executor, checkpoint_dir, main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The values of 'checkpoint_dir' or 'serial' should not be None")
|
|
|
|
|
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
if serial is None or serial < 0:
|
|
|
|
|
raise ValueError("The values of 'serial' should not be None or <0 ")
|
|
|
|
|
|
|
|
|
|
if serial < 0:
|
|
|
|
|
return
|
|
|
|
|
if main_program is None:
|
|
|
|
|
raise ValueError("The values of 'main_program'should not be None")
|
|
|
|
|
|
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
|
cur_dir = _get_model_dir(cur_dir)
|
|
|
|
|
load_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -552,6 +546,68 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
|
os.rmdir(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_persist_vars_without_grad(executor, dirname, program, nest=True):
    """
    Load checkpoint variables from a directory through an executor.

    The work is delegated to ``load_vars`` with ``_is_checkpoint_var`` as the
    filter; variables whose names end with "@GRAD" are not loaded.

    :param executor: executor handed to ``load_vars``
    :param dirname: base directory of the saved variables
    :param program: program passed to ``load_vars`` as ``main_program``
    :param nest: when true, read from the directory returned by
        ``_get_model_dir(dirname)`` instead of ``dirname`` itself
    """
    source_dir = _get_model_dir(dirname) if nest else dirname
    load_options = dict(
        dirname=source_dir,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)
    load_vars(executor, **load_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_persist_vars_without_grad(executor, dirname, program):
    """
    Save checkpoint variables under a directory through an executor.

    Variables are written by ``save_vars`` into the directory returned by
    ``_get_model_dir(dirname)``, filtered by ``_is_checkpoint_var``; variables
    whose names end with "@GRAD" are not saved. After ``save_vars`` returns,
    ``_write_success`` is called on that same directory.

    :param executor: executor handed to ``save_vars``
    :param dirname: base directory; the model sub-directory is derived from it
    :param program: program passed to ``save_vars`` as ``main_program``
    """
    model_dir = _get_model_dir(dirname)
    save_options = dict(
        dirname=model_dir,
        main_program=program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)
    save_vars(executor, **save_options)
    _write_success(model_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_trainer_args(dirname, trainer_id, trainer_args):
    """
    Write each trainer argument to its own file in the trainer's directory.

    For every ``name``/``value`` pair a file called ``name`` is written inside
    the directory returned by ``_get_trainer_dir(dirname, trainer_id)``, with
    ``str(value)`` as its content. ``_write_success`` is then called on that
    directory.

    :param dirname: base directory the trainer directory is derived from
    :param trainer_id: id used to build the trainer directory name
    :param trainer_args: dict mapping argument name to value
    :raises TypeError: if ``trainer_args`` is not a dict (this includes None)
    """
    if not isinstance(trainer_args, dict):
        raise TypeError("The type of 'trainer_args' should be dict")
    cur_dir = _get_trainer_dir(dirname, trainer_id)

    # items() is available on both Python 2 and Python 3, whereas
    # iteritems() only exists on Python 2 dicts.
    for name, value in trainer_args.items():
        args_file = os.path.join(cur_dir, name)
        with open(args_file, 'w') as f:
            f.write(str(value))
    _write_success(cur_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
    """
    Read previously saved trainer arguments back from a checkpoint.

    Each requested argument is read from the file of the same name inside the
    trainer directory of the given checkpoint serial.

    :param checkpoint_dir: root checkpoint directory
    :param serial: serial of the checkpoint to read from
    :param trainer_id: id used to build the trainer directory name
    :param trainer_args: list of argument names to read
    :return: list of file contents (whitespace-stripped strings), in the same
        order as ``trainer_args``
    :raises TypeError: if ``trainer_args`` is not a list
    """
    # Validate before resolving directories: the directory helpers may create
    # folders, which should not happen for a call that is rejected anyway.
    if not isinstance(trainer_args, list):
        raise TypeError("The type of 'trainer_args' should be list")

    cur_dir = _get_serial_dir(checkpoint_dir, serial)
    cur_dir = _get_trainer_dir(cur_dir, trainer_id)

    ret_values = []

    for arg in trainer_args:
        cur_file = os.path.join(cur_dir, arg)
        with open(cur_file, 'r') as f:
            contents = f.read()
            ret_values.append(contents.strip())
    return ret_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_checkpoint_var(var):
|
|
|
|
|
"""
|
|
|
|
|
the checkpoint will not save or load all the variables.
|
|
|
|
@ -583,7 +639,31 @@ def _get_dir_serial(dirname):
|
|
|
|
|
|
|
|
|
|
def _get_serial_dir(dirname, serial):
    """
    Return the folder for checkpoint number ``serial``, creating it if needed.

    The folder name is CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial),
    joined onto ``dirname``.

    :param dirname: root checkpoint directory
    :param serial: checkpoint serial number
    :return: path of the serial folder
    """
    serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
    serial_dir = os.path.join(dirname, serial_folder)

    # A leftover early return used to make this creation step unreachable;
    # create the folder the same way _get_model_dir/_get_trainer_dir do.
    if not os.path.isdir(serial_dir):
        os.makedirs(serial_dir)

    return serial_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_model_dir(dirname):
    """
    Return the MODEL_DIR sub-folder of ``dirname``, creating it if absent.

    :param dirname: parent directory
    :return: path of the model sub-folder
    """
    target = os.path.join(dirname, MODEL_DIR)
    if os.path.isdir(target):
        return target

    os.makedirs(target)
    return target
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_trainer_dir(dirname, trainer_id):
    """
    Return the folder belonging to ``trainer_id`` under ``dirname``, creating
    it if absent.

    The folder name is TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id).

    :param dirname: parent directory
    :param trainer_id: id appended to the trainer folder prefix
    :return: path of the trainer folder
    """
    folder_name = CHECKPOINT_SEPARATOR.join([TRAINER_PREFIX, str(trainer_id)])
    target = os.path.join(dirname, folder_name)
    if os.path.isdir(target):
        return target

    os.makedirs(target)
    return target
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lru_delete(dirname, max_num_checkpoints=3):
|
|
|
|
@ -638,7 +718,8 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
success_path = os.path.join(
|
|
|
|
|
_get_serial_dir(checkpoint_dir, serial), SUCCESS_MARK_FILENAME)
|
|
|
|
|
_get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
|
|
|
|
|
SUCCESS_MARK_FILENAME)
|
|
|
|
|
if os.path.isfile(success_path):
|
|
|
|
|
return serial
|
|
|
|
|
|
|
|
|
|