|
|
|
@ -13,21 +13,18 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
import shutil
|
|
|
|
|
|
|
|
|
|
from paddle.fluid.evaluator import Evaluator
|
|
|
|
|
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
|
|
|
|
|
from . import core
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'save_vars',
|
|
|
|
|
'save_params',
|
|
|
|
|
'save_persistables',
|
|
|
|
|
'load_vars',
|
|
|
|
|
'load_params',
|
|
|
|
|
'load_persistables',
|
|
|
|
|
'save_inference_model',
|
|
|
|
|
'load_inference_model',
|
|
|
|
|
'get_inference_program',
|
|
|
|
|
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
|
|
|
|
|
'load_persistables', 'save_inference_model', 'load_inference_model',
|
|
|
|
|
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
|
|
|
|
|
'clean_checkpoint'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -195,6 +192,8 @@ def load_vars(executor,
|
|
|
|
|
load_var_map = {}
|
|
|
|
|
for each_var in vars:
|
|
|
|
|
assert isinstance(each_var, Variable)
|
|
|
|
|
if each_var.type == core.VarDesc.VarType.RAW:
|
|
|
|
|
continue
|
|
|
|
|
new_var = _clone_var_in_block_(load_block, each_var)
|
|
|
|
|
if filename is None:
|
|
|
|
|
load_block.append_op(
|
|
|
|
@ -454,3 +453,192 @@ def get_parameter_value_by_name(name, executor, program=None):
|
|
|
|
|
program = default_main_program()
|
|
|
|
|
var = program.global_block().var(name)
|
|
|
|
|
return get_parameter_value(var, executor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUCCESS_MARK_FILENAME = "_SUCCESS"
|
|
|
|
|
CHECKPOINT_PREFIX = "checkpoint"
|
|
|
|
|
CHECKPOINT_SEPARATOR = "_"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(executor,
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
max_num_checkpoints=3,
|
|
|
|
|
save_interval_secs=600,
|
|
|
|
|
main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
|
|
|
|
|
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
|
|
|
|
|
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
|
|
|
|
|
The interval between two saved checkpoints must greater than save_interval_secs.
|
|
|
|
|
|
|
|
|
|
:param executor
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
:param max_num_checkpoints
|
|
|
|
|
:param save_interval_secs
|
|
|
|
|
:param main_program
|
|
|
|
|
"""
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
checkpoint_dir = os.getcwd()
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
os.makedirs(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
if serial >= 0 and not _interval_secs_exceed(
|
|
|
|
|
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
serial += 1
|
|
|
|
|
cur_dir = _get_serial_dir(serial, checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
save_vars(
|
|
|
|
|
executor,
|
|
|
|
|
dirname=cur_dir,
|
|
|
|
|
main_program=main_program,
|
|
|
|
|
vars=None,
|
|
|
|
|
predicate=_is_checkpoint_var,
|
|
|
|
|
filename=None)
|
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Load checkpoint from a directory by executor,
|
|
|
|
|
it will find the most recent saved checkpoint file and load it auto.
|
|
|
|
|
|
|
|
|
|
:param executor
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
:param main_program
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
checkpoint_dir = os.getcwd()
|
|
|
|
|
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
if serial < 0:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
cur_dir = _get_serial_dir(serial, checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
load_vars(
|
|
|
|
|
executor,
|
|
|
|
|
dirname=cur_dir,
|
|
|
|
|
main_program=main_program,
|
|
|
|
|
predicate=_is_checkpoint_var,
|
|
|
|
|
filename=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
|
"""
|
|
|
|
|
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
|
|
|
|
|
delete_dir only works when the directory is empty, otherwise, OSError is raised.
|
|
|
|
|
"""
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
checkpoint_dir = os.getcwd()
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
|
|
|
|
|
|
|
|
|
|
if delete_dir and not os.listdir(checkpoint_dir):
|
|
|
|
|
os.rmdir(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_serial_dir(serial, checkpoint_dir):
|
|
|
|
|
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
|
|
|
|
|
return os.path.join(checkpoint_dir, serial_folder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_checkpoint_var(var):
|
|
|
|
|
"""
|
|
|
|
|
the checkpoint will not save or load all the variables.
|
|
|
|
|
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
|
|
|
|
|
|
|
|
|
|
:param var
|
|
|
|
|
"""
|
|
|
|
|
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.RAW:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
if var.name.endswith("@GRAD"):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
return var.persistable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _interval_secs_exceed(dirname, save_interval_secs):
|
|
|
|
|
dir_time = os.path.getmtime(dirname)
|
|
|
|
|
if save_interval_secs > (time.time() - dir_time):
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lru_delete(dirname, max_num_checkpoints=3):
|
|
|
|
|
dirs = os.listdir(dirname)
|
|
|
|
|
serials = []
|
|
|
|
|
for serial in dirs:
|
|
|
|
|
try:
|
|
|
|
|
serials.append(int(serial))
|
|
|
|
|
except ValueError:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if len(serials) <= max_num_checkpoints:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
serials.sort(reverse=True)
|
|
|
|
|
serials = serials[max_num_checkpoints:]
|
|
|
|
|
for serial in serials:
|
|
|
|
|
cur_dir = os.path.join(dirname, str(serial))
|
|
|
|
|
shutil.rmtree(cur_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _write_success(dirname):
|
|
|
|
|
"""
|
|
|
|
|
write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.
|
|
|
|
|
|
|
|
|
|
:param dirname
|
|
|
|
|
"""
|
|
|
|
|
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
|
|
|
|
|
with open(success_file, 'a') as f:
|
|
|
|
|
now = time.ctime()
|
|
|
|
|
f.write(now)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_lastest_checkpoint_dir(checkpoint_dir):
|
|
|
|
|
"""
|
|
|
|
|
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
|
|
|
|
|
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
"""
|
|
|
|
|
if not checkpoint_dir.strip():
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
def has_success(checkpoint_dir, cur_dir):
|
|
|
|
|
"""
|
|
|
|
|
is _SUCCESS in this dir
|
|
|
|
|
"""
|
|
|
|
|
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
int(serial)
|
|
|
|
|
except ValueError:
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
success_path = os.path.join(
|
|
|
|
|
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
|
|
|
|
|
if os.path.isfile(success_path):
|
|
|
|
|
return int(serial)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
current_dir = -1
|
|
|
|
|
dirs = os.listdir(checkpoint_dir)
|
|
|
|
|
for cur_dir in dirs:
|
|
|
|
|
success_num = has_success(checkpoint_dir, cur_dir)
|
|
|
|
|
if success_num > current_dir:
|
|
|
|
|
current_dir = success_num
|
|
|
|
|
return current_dir
|
|
|
|
|