|
|
|
@ -492,7 +492,7 @@ def save_checkpoint(executor,
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
os.makedirs(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
|
|
|
|
|
serial = _get_latest_checkpoint_dir(checkpoint_dir) + 1
|
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
|
|
|
|
|
|
save_trainer_args(cur_dir, trainer_id, trainer_args)
|
|
|
|
@ -505,11 +505,11 @@ def save_checkpoint(executor,
|
|
|
|
|
|
|
|
|
|
def get_latest_checkpoint_serial(checkpoint_dir):
|
|
|
|
|
"""
|
|
|
|
|
If the directory have checkpoint files, it will return lastest checkpoint directory serial number
|
|
|
|
|
If the directory have checkpoint files, it will return latest checkpoint directory serial number
|
|
|
|
|
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
"""
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
serial = _get_latest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
if serial < 0:
|
|
|
|
|
return None
|
|
|
|
|
return serial
|
|
|
|
@ -639,14 +639,14 @@ def _is_checkpoint_var(var):
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.RAW:
|
|
|
|
|
return False
|
|
|
|
|
# @GRAD are named for gradient varibales, checkpoint will not save it.
|
|
|
|
|
# @GRAD are named for gradient variables, checkpoint will not save it.
|
|
|
|
|
if "@GRAD" in var.name:
|
|
|
|
|
return False
|
|
|
|
|
# .trainer_ are named for distribute trian variables, checkpoint will not save it.
|
|
|
|
|
# .trainer_ are named for distribute train variables, checkpoint will not save it.
|
|
|
|
|
if ".trainer_" in var.name:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
# .block is named for distribute trian variables, checkpoint will not save it.
|
|
|
|
|
# .block is named for distribute train variables, checkpoint will not save it.
|
|
|
|
|
if ".block" in var.name:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
@ -656,7 +656,6 @@ def _is_checkpoint_var(var):
|
|
|
|
|
def _get_dir_serial(dirname):
|
|
|
|
|
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
|
|
|
|
|
|
|
|
|
|
serial_num = -1
|
|
|
|
|
try:
|
|
|
|
|
serial_num = int(serial)
|
|
|
|
|
except ValueError:
|
|
|
|
@ -723,7 +722,7 @@ def _write_success(dirname):
|
|
|
|
|
f.write(now)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_lastest_checkpoint_dir(checkpoint_dir):
|
|
|
|
|
def _get_latest_checkpoint_dir(checkpoint_dir):
|
|
|
|
|
"""
|
|
|
|
|
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
|
|
|
|
|
|
|
|
|
|