|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import errno
|
|
|
|
|
import time
|
|
|
|
|
import shutil
|
|
|
|
|
|
|
|
|
@ -881,11 +882,9 @@ def save_checkpoint(executor,
|
|
|
|
|
if trainer_args:
|
|
|
|
|
assert isinstance(trainer_args, dict)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
os.makedirs(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
is_chief = trainer_id == 0
|
|
|
|
|
|
|
|
|
|
_make_chekcpoint_dirs(checkpoint_dir)
|
|
|
|
|
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
|
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
|
|
|
|
|
@ -1251,6 +1250,20 @@ def _is_checkpoint_var(var):
|
|
|
|
|
return var.persistable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _make_chekcpoint_dirs(dirs):
|
|
|
|
|
assert dirs is not None
|
|
|
|
|
|
|
|
|
|
if os.path.isfile(dirs):
|
|
|
|
|
raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(dirs):
|
|
|
|
|
try:
|
|
|
|
|
os.makedirs(dirs)
|
|
|
|
|
except OSError as err:
|
|
|
|
|
if err.errno != errno.EEXIST:
|
|
|
|
|
raise err
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_dir_serial(dirname):
|
|
|
|
|
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
|
|
|
|
|
|
|
|
|
@ -1264,38 +1277,27 @@ def _get_dir_serial(dirname):
|
|
|
|
|
def _get_serial_dir(dirname, serial):
|
|
|
|
|
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
|
|
|
|
|
serial_dir = os.path.join(dirname, serial_folder)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(serial_dir):
|
|
|
|
|
os.makedirs(serial_dir)
|
|
|
|
|
_make_chekcpoint_dirs(serial_dir)
|
|
|
|
|
|
|
|
|
|
return serial_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_model_dir(dirname):
|
|
|
|
|
model_dir = os.path.join(dirname, MODEL_DIR)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(model_dir):
|
|
|
|
|
os.makedirs(model_dir)
|
|
|
|
|
|
|
|
|
|
_make_chekcpoint_dirs(model_dir)
|
|
|
|
|
return model_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_lookuptable_dir(dirname):
|
|
|
|
|
lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(lookuptable_dir):
|
|
|
|
|
os.makedirs(lookuptable_dir)
|
|
|
|
|
|
|
|
|
|
_make_chekcpoint_dirs(lookuptable_dir)
|
|
|
|
|
return lookuptable_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_trainer_dir(dirname, trainer_id):
|
|
|
|
|
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
|
|
|
|
|
trainer_dir = os.path.join(dirname, trainer_folder)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(trainer_dir):
|
|
|
|
|
os.makedirs(trainer_dir)
|
|
|
|
|
|
|
|
|
|
_make_chekcpoint_dirs(trainer_dir)
|
|
|
|
|
return trainer_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1314,7 +1316,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3):
|
|
|
|
|
serials = serials[max_num_checkpoints:]
|
|
|
|
|
for serial in serials:
|
|
|
|
|
cur_dir = _get_serial_dir(dirname, serial)
|
|
|
|
|
try:
|
|
|
|
|
shutil.rmtree(cur_dir)
|
|
|
|
|
except OSError as err:
|
|
|
|
|
if err.errno != errno.ENOENT:
|
|
|
|
|
raise err
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _write_success(dirname):
|
|
|
|
|