port
tangwei12 7 years ago
parent 620999c917
commit 8e01f3b948

@ -13,6 +13,7 @@
# limitations under the License.
import os
import errno
import time
import shutil
@ -881,11 +882,9 @@ def save_checkpoint(executor,
if trainer_args:
assert isinstance(trainer_args, dict)
if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)
is_chief = trainer_id == 0
_make_chekcpoint_dirs(checkpoint_dir)
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
@ -1251,6 +1250,20 @@ def _is_checkpoint_var(var):
return var.persistable
def _make_chekcpoint_dirs(dirs):
    """
    Create the directory `dirs` (including missing parents) if it is absent.

    Calling this on a directory that already exists is a no-op, and an
    EEXIST error from os.makedirs is tolerated as long as the path turns
    out to be a directory afterwards (e.g. it was created by someone else
    between the existence check and the makedirs call).

    Args:
        dirs(str): path of the directory to create, must not be None.

    Raises:
        AssertionError: if `dirs` is None.
        OSError: with errno.ENOTDIR if `dirs` is an existing regular file;
            any error raised by os.makedirs other than EEXIST; or the EEXIST
            error itself when the existing path is not a directory.
    """
    assert dirs is not None

    if os.path.isfile(dirs):
        raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs)

    if os.path.isdir(dirs):
        return

    try:
        os.makedirs(dirs)
    except OSError as err:
        # EEXIST is only harmless when the thing that exists is a directory;
        # otherwise the caller would go on to use a path that is not usable.
        # Bare `raise` keeps the original traceback intact.
        if err.errno != errno.EEXIST or not os.path.isdir(dirs):
            raise
def _get_dir_serial(dirname):
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
@ -1264,38 +1277,27 @@ def _get_dir_serial(dirname):
def _get_serial_dir(dirname, serial):
    """
    Return the checkpoint directory for `serial` under `dirname`, creating
    it when it does not exist yet.

    Args:
        dirname(str): parent directory that holds the serial directories.
        serial(int): serial number; it is converted with str() and appended
            to CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR to form the folder name.

    Returns:
        str: path of the serial directory.
    """
    serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
    serial_dir = os.path.join(dirname, serial_folder)

    # Creation goes through _make_chekcpoint_dirs only: a separate
    # isdir()/makedirs() pair here could raise EEXIST if the directory
    # appears between the check and the call, which the helper tolerates.
    _make_chekcpoint_dirs(serial_dir)
    return serial_dir
def _get_model_dir(dirname):
    """
    Return the MODEL_DIR sub-directory of `dirname`, creating it when it
    does not exist yet.

    Args:
        dirname(str): directory that contains (or will contain) MODEL_DIR.

    Returns:
        str: path of the model directory.
    """
    model_dir = os.path.join(dirname, MODEL_DIR)

    # Creation goes through _make_chekcpoint_dirs only: a separate
    # isdir()/makedirs() pair here could raise EEXIST if the directory
    # appears between the check and the call, which the helper tolerates.
    _make_chekcpoint_dirs(model_dir)
    return model_dir
def _get_lookuptable_dir(dirname):
    """
    Return the LOOKUP_TABLE_DIR sub-directory of `dirname`, creating it when
    it does not exist yet.

    Args:
        dirname(str): directory that contains (or will contain)
            LOOKUP_TABLE_DIR.

    Returns:
        str: path of the lookup table directory.
    """
    lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)

    # Creation goes through _make_chekcpoint_dirs only: a separate
    # isdir()/makedirs() pair here could raise EEXIST if the directory
    # appears between the check and the call, which the helper tolerates.
    _make_chekcpoint_dirs(lookuptable_dir)
    return lookuptable_dir
def _get_trainer_dir(dirname, trainer_id):
    """
    Return the directory for `trainer_id` under `dirname`, creating it when
    it does not exist yet.

    Args:
        dirname(str): parent directory that holds the trainer directories.
        trainer_id(int): trainer id; it is converted with str() and appended
            to TRAINER_PREFIX + CHECKPOINT_SEPARATOR to form the folder name.

    Returns:
        str: path of the trainer directory.
    """
    trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
    trainer_dir = os.path.join(dirname, trainer_folder)

    # Creation goes through _make_chekcpoint_dirs only: a separate
    # isdir()/makedirs() pair here could raise EEXIST if the directory
    # appears between the check and the call, which the helper tolerates.
    _make_chekcpoint_dirs(trainer_dir)
    return trainer_dir
@ -1314,7 +1316,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3):
serials = serials[max_num_checkpoints:]
for serial in serials:
cur_dir = _get_serial_dir(dirname, serial)
shutil.rmtree(cur_dir)
try:
shutil.rmtree(cur_dir)
except OSError as err:
if err.errno != errno.ENOENT:
raise err
def _write_success(dirname):

Loading…
Cancel
Save