|
|
|
@ -24,7 +24,8 @@ __all__ = [
|
|
|
|
|
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
|
|
|
|
|
'load_persistables', 'save_inference_model', 'load_inference_model',
|
|
|
|
|
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
|
|
|
|
|
'clean_checkpoint'
|
|
|
|
|
'clean_checkpoint', 'load_persist_vars_without_grad',
|
|
|
|
|
'save_persist_vars_without_grad'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -455,6 +456,33 @@ def get_parameter_value_by_name(name, executor, program=None):
|
|
|
|
|
return get_parameter_value(var, executor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_persist_vars_without_grad(executor, dirname, program):
|
|
|
|
|
"""
|
|
|
|
|
load_persist_vars_without_grad will load variables from a directory by an executor,
|
|
|
|
|
the variable named end with "@GRAD" will not be loaded.
|
|
|
|
|
"""
|
|
|
|
|
load_vars(
|
|
|
|
|
executor,
|
|
|
|
|
dirname=dirname,
|
|
|
|
|
main_program=program,
|
|
|
|
|
predicate=_is_checkpoint_var,
|
|
|
|
|
filename=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_persist_vars_without_grad(executor, dirname, program):
|
|
|
|
|
"""
|
|
|
|
|
save_persist_vars_without_grad will save variables to a directory by an executor,
|
|
|
|
|
the variable named end with "@GRAD" will not be saved.
|
|
|
|
|
"""
|
|
|
|
|
save_vars(
|
|
|
|
|
executor,
|
|
|
|
|
dirname=dirname,
|
|
|
|
|
main_program=program,
|
|
|
|
|
vars=None,
|
|
|
|
|
predicate=_is_checkpoint_var,
|
|
|
|
|
filename=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUCCESS_MARK_FILENAME = "_SUCCESS"
|
|
|
|
|
CHECKPOINT_PREFIX = "checkpoint"
|
|
|
|
|
CHECKPOINT_SEPARATOR = "_"
|
|
|
|
@ -491,13 +519,7 @@ def save_checkpoint(executor,
|
|
|
|
|
serial += 1
|
|
|
|
|
cur_dir = _get_serial_dir(serial, checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
save_vars(
|
|
|
|
|
executor,
|
|
|
|
|
dirname=cur_dir,
|
|
|
|
|
main_program=main_program,
|
|
|
|
|
vars=None,
|
|
|
|
|
predicate=_is_checkpoint_var,
|
|
|
|
|
filename=None)
|
|
|
|
|
load_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
@ -521,13 +543,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
cur_dir = _get_serial_dir(serial, checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
load_vars(
|
|
|
|
|
executor,
|
|
|
|
|
dirname=cur_dir,
|
|
|
|
|
main_program=main_program,
|
|
|
|
|
predicate=_is_checkpoint_var,
|
|
|
|
|
filename=None)
|
|
|
|
|
load_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
|