# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
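"""Automatic checkpointing for Paddle EDL training jobs.

A rough usage sketch (the import path is assumed from this file's location;
the executor, program and reader names are illustrative placeholders, not
part of this module):

    import paddle.fluid.incubate.checkpoint.auto_checkpoint as acp

    for epoch in acp.train_epoch_range(max_epoch_num):
        for data in reader():
            exe.run(program, feed=feed(data))

When the job runs under PADDLE_RUNNING_ENV=PADDLE_EDL_AUTO_CHECKPOINT,
training state is periodically saved to HDFS and restored on restart;
otherwise train_epoch_range degrades to a plain range over epochs.
"""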

import sys
import logging
import hashlib
import json
import os
import six
import time
import collections
from threading import Thread, current_thread
from contextlib import contextmanager

from paddle.fluid import unique_name, compiler
from .checkpoint_saver import SerializableBase, CheckpointSaver, PaddleModel
from paddle.fluid.framework import in_dygraph_mode, Program

g_train_epoch_range = None
g_checker = None

logger = None

generator = unique_name.UniqueNameGenerator()

CONST_CHECKPOINT = "checkpoint"
CONST_MEMORYINIT = "memory_init"

# Auto checkpoint triggered by dataloader events.
CONST_DACP_TYPE = "dacp"
# Auto checkpoint triggered by an epoch loop range.
CONST_ACP_TYPE = "acp"
g_acp_type = None
g_program_attr = {}  # program_name -> can_be_auto_checkpoint


def _get_logger(log_level, name="auto_checkpoint"):
    global logger
    if logger is not None:
        return logger

    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    logger.propagate = False

    log_handler = logging.StreamHandler()
    log_format = logging.Formatter(
        '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
    log_handler.setFormatter(log_format)
    logger.addHandler(log_handler)

    return logger


def _thread_checker():
    assert current_thread().name == "MainThread", \
        "auto checkpoint must run under the main thread"


class AutoCheckpointChecker(object):
    def __init__(self):
        self._run_env = None
        self._platform = None
        self._job_id = None
        self._hdfs_home = None
        self._hdfs_name = None
        self._hdfs_ugi = None
        self._hdfs_checkpoint_path = None
        self._trainer_id = None
        self._ce_test = None

        self._run_env = os.getenv("PADDLE_RUNNING_ENV")
        if self._run_env != "PADDLE_EDL_AUTO_CHECKPOINT":
            return

        try:
            self._platform = os.environ["PADDLE_RUNNING_PLATFORM"]
            self._job_id = os.environ["PADDLE_JOB_ID"]
            self._hdfs_home = os.environ["PADDLE_EDL_HDFS_HOME"]
            self._hdfs_name = os.environ["PADDLE_EDL_HDFS_NAME"]
            self._hdfs_ugi = os.environ["PADDLE_EDL_HDFS_UGI"]
            self._hdfs_checkpoint_path = os.environ[
                "PADDLE_EDL_HDFS_CHECKPOINT_PATH"]
            self._trainer_id = int(os.environ["PADDLE_TRAINER_ID"])

            self._ce_test = int(os.getenv("PADDLE_EDL_ONLY_FOR_CE_TEST", "0"))
            self._fs_cache = os.getenv("PADDLE_EDL_FS_CACHE", ".cache")

            self._save_checkpoint_inter = int(
                os.getenv("PADDLE_EDL_SAVE_CHECKPOINT_INTER", "900"))  # seconds

            if not self._ce_test:
                assert len(self._hdfs_home) > 3 and \
                    len(self._hdfs_name) > 6 and \
                    len(self._hdfs_ugi) > 3 and \
                    len(self._hdfs_checkpoint_path) > 0, \
                    "HDFS environment variables must be set"
            else:
                assert len(self._hdfs_home) > 3 and \
                    len(self._hdfs_checkpoint_path) > 0, \
                    "HDFS environment variables must be set"

        except Exception as e:
            logger.fatal("exception:{}".format(e))
            sys.exit(1)

    def get_range_checkpoint_path(self, name):
        return "{}/{}/range/{}".format(self.hdfs_checkpoint_path, self.job_id,
                                       name)

    def get_exe_checkpoint_path(self, name):
        return "{}/{}/exe/{}".format(self.hdfs_checkpoint_path, self.job_id,
                                     name)

    def get_job_path(self):
        return "{}/{}".format(self.hdfs_checkpoint_path, self.job_id)

    @property
    def save_checkpoint_inter(self):
        return self._save_checkpoint_inter

    def valid(self):
        if in_dygraph_mode():
            return False

        return self._run_env is not None and \
            self._platform is not None and \
            self._job_id is not None and \
            self._hdfs_home is not None and \
            self._hdfs_name is not None and \
            self._hdfs_ugi is not None and \
            self._hdfs_checkpoint_path is not None and \
            self._trainer_id is not None

    def __str__(self):
        return "run_env:{} platform:{} job_id:{} \
hdfs_home:{} hdfs_name:{} hdfs_ugi:{} \
hdfs_checkpoint_path:{} trainer_id:{} ce_test:{}".format(
            self._run_env, self._platform, self._job_id, self._hdfs_home,
            self._hdfs_name, self._hdfs_ugi, self._hdfs_checkpoint_path,
            self._trainer_id, self._ce_test)

    @property
    def trainer_id(self):
        return self._trainer_id

    @property
    def run_env(self):
        return self._run_env

    @property
    def platform(self):
        return self._platform

    @property
    def job_id(self):
        return self._job_id

    @property
    def hdfs_home(self):
        return self._hdfs_home

    @property
    def hdfs_name(self):
        return self._hdfs_name

    @property
    def ce_test(self):
        return self._ce_test

    @property
    def hdfs_ugi(self):
        return self._hdfs_ugi

    @property
    def hdfs_checkpoint_path(self):
        return self._hdfs_checkpoint_path

    @staticmethod
    def generate_range_name():
        return generator("_range_")

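
# A sketch of the environment AutoCheckpointChecker reads. The variable
# names come from the code above; the values below are purely illustrative:
#
#   export PADDLE_RUNNING_ENV=PADDLE_EDL_AUTO_CHECKPOINT
#   export PADDLE_RUNNING_PLATFORM=PADDLE_CLOUD        # hypothetical value
#   export PADDLE_JOB_ID=job-0001                      # hypothetical value
#   export PADDLE_EDL_HDFS_HOME=/usr/local/hadoop      # hypothetical value
#   export PADDLE_EDL_HDFS_NAME=hdfs://nameservice     # hypothetical value
#   export PADDLE_EDL_HDFS_UGI=user,passwd             # hypothetical value
#   export PADDLE_EDL_HDFS_CHECKPOINT_PATH=/checkpoints
#   export PADDLE_TRAINER_ID=0
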
class ExeTrainStatus(SerializableBase):
    def __init__(self):
        self._epoch_no = -1  # start epoch_no
        self._hash_key = None
        self._key = None
        self._checkpoint_path = None
        self._checkpoint_no = None
        self._restored_from = None
        self._exe = None
        self._program = None
        self._exe_name = None
        self._program_name = None

        self._file_name = "exe_train_status"

    def __eq__(self, t):
        return self._epoch_no == t._epoch_no and \
            self._hash_key == t._hash_key and \
            self._key == t._key and \
            self._checkpoint_path == t._checkpoint_path and \
            self._checkpoint_no == t._checkpoint_no and \
            self._exe_name == t._exe_name and \
            self._program_name == t._program_name

    def __ne__(self, t):
        return not self == t

    def serialize(self, path):
        file_name = "{}/{}".format(path, self._file_name)
        with open(file_name, 'w') as f:
            f.write(self._serialize())

    def _serialize(self, pop_keys=["restored_from"]):
        d = self._to_dict()
        for k in pop_keys:
            d.pop(k, None)
        return json.dumps(d)

    def deserialize(self, path):
        file_name = "{}/{}".format(path, self._file_name)
        with open(file_name, 'r') as f:
            self._deserialize(f.read())

    def _deserialize(self, s):
        d = json.loads(s)
        self._epoch_no = d["epoch_no"]
        self._key = d["key"]
        self._hash_key = d["hash_key"]
        self._checkpoint_path = d["checkpoint_path"]
        self._checkpoint_no = d["checkpoint_no"]
        self._exe_name = d["exe_name"]
        self._program_name = d["program_name"]

    def _to_dict(self):
        return {
            "epoch_no": self._epoch_no,
            "key": self._key,
            "hash_key": self._hash_key,
            "checkpoint_path": self._checkpoint_path,
            "restored_from": self._restored_from,
            "exe_name": self._exe_name,
            "program_name": self._program_name,
            "checkpoint_no": self._checkpoint_no
        }

    def __str__(self):
        return self._serialize([])

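
# For reference, _serialize() above produces JSON like the following (the
# field values here are illustrative, not taken from a real run; note that
# "restored_from" is popped by default):
#
#   {"epoch_no": 3, "key": "exe_0_program_0", "hash_key": "exe_0_program_0",
#    "checkpoint_path": "<hdfs path>", "exe_name": "exe_0",
#    "program_name": "program_0", "checkpoint_no": 0}
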
class TrainEpochRange(SerializableBase):
    def __init__(self,
                 max_epoch_num,
                 name,
                 checkpoint_inter=None,
                 restored=True):
        self._max_epoch_num = max_epoch_num
        self._epoch_no = -1  # current epoch_no
        self._name = name
        self._restored_from = None
        self._exe_status = {}
        self._flag_generated = False

        self._checker = g_checker
        if checkpoint_inter is not None:
            self._save_checkpoint_inter = checkpoint_inter
        else:
            self._save_checkpoint_inter = self._checker.save_checkpoint_inter
        assert self._save_checkpoint_inter >= 0, \
            "save_checkpoint_inter:{} must be >= 0".format(
                self._save_checkpoint_inter)
        self._last_checkpoint_time = time.time()

        self._load_cp_nos = None
        self._checkpoint_epoch_no = None

        if not self._checker.valid():
            return

        self._file_name = "range_train_status"

        if not restored:
            return

        self._checkpoint_path = self._checker.get_range_checkpoint_path(name)

        config = {
            "fs.default.name": self._checker.hdfs_name,
            "hadoop.job.ugi": self._checker.hdfs_ugi
        }

        if self._checker.ce_test:
            config = None

        from paddle.distributed.fleet.utils.fs import HDFSClient
        self._hdfs = HDFSClient(self._checker.hdfs_home, config)

        self._cper = CheckpointSaver(self._hdfs)

        _thread_checker()

        self._get_last_valid_checkpoint()

    def _look_for_valid(self, cp_nos):
        cps = []
        epoch_no = -1
        for i in cp_nos[::-1]:
            t = TrainEpochRange(self._max_epoch_num, self.name, restored=False)
            self._cper.load_checkpoint(
                self._checkpoint_path, [t],
                self._checker.trainer_id,
                checkpoint_no=i,
                local_cache_path=self._checker._fs_cache)
            cps.append(t)
            logger.debug("look for valid:{} t:{}".format(i, t._serialize()))
            if epoch_no < 0:
                epoch_no = t._epoch_no
            else:
                if epoch_no - t._epoch_no >= 1:
                    return t, i
        return None, None

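    # Restore strategy: under ACP the newest checkpoint is loaded directly;
    # under DACP, _look_for_valid scans backwards for a checkpoint whose
    # epoch is at least one behind the newest one.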
    def _get_last_valid_checkpoint(self):
        self._load_cp_nos = self._cper.get_checkpoint_no(self._checkpoint_path)
        logger.info("find checkpoint nos:{}".format(self._load_cp_nos))

        if len(self._load_cp_nos) < 1:
            self._restored_from = CONST_MEMORYINIT
            return

        if g_acp_type == CONST_ACP_TYPE:
            # Get the last one.
            self._cper.load_checkpoint(
                self._checkpoint_path, [self],
                self._checker.trainer_id,
                local_cache_path=self._checker._fs_cache)
            self._restored_from = CONST_CHECKPOINT
            self._checkpoint_epoch_no = self._epoch_no

            logger.info("load train_epoch_range checkpoint:{}".format(
                self._serialize()))

        elif g_acp_type == CONST_DACP_TYPE:
            t, i = self._look_for_valid(self._load_cp_nos)
            if t is None:
                self._restored_from = CONST_MEMORYINIT
                return

            self._cper.load_checkpoint(
                self._checkpoint_path, [self],
                self._checker.trainer_id,
                checkpoint_no=i,
                local_cache_path=self._checker._fs_cache)

            self._restored_from = CONST_CHECKPOINT
            self._checkpoint_epoch_no = self._epoch_no
            logger.info("load train_epoch_range checkpoint:{}".format(
                self._serialize()))
        else:
            assert False, "not supported acp_type:{}".format(g_acp_type)

    def _to_dict(self):
        return {
            "max_epoch_num": self._max_epoch_num,
            "epoch_no": self._epoch_no,
            "name": self._name,
            "checkpoint_path": self._checkpoint_path,
            "restored_from": self._restored_from,
            "checkpoint_epoch_no": self._checkpoint_epoch_no
        }

    def __str__(self):
        return self._serialize([])

    @property
    def name(self):
        return self._name

    def serialize(self, path):
        file_name = "{}/{}".format(path, self._file_name)
        with open(file_name, 'w') as f:
            f.write(self._serialize())

    def _serialize(self, pop_keys=["restored_from", "checkpoint_epoch_no"]):
        # self
        d = self._to_dict()
        for k in pop_keys:
            d.pop(k, None)

        # registered exes
        d["exe_status"] = {}
        e = d["exe_status"]
        for k, t in six.iteritems(self._exe_status):
            e[t._key] = t._serialize()
        return json.dumps(d)

    @property
    def restored_from(self):
        return self._restored_from

    def deserialize(self, path):
        d = None
        file_name = "{}/{}".format(path, self._file_name)
        with open(file_name, 'r') as f:
            d = json.load(f)

        # self
        self._max_epoch_num = d["max_epoch_num"]
        self._epoch_no = d["epoch_no"]
        self._name = d["name"]
        self._checkpoint_path = d["checkpoint_path"]

        # exe status
        e = d["exe_status"]
        for k, v in six.iteritems(e):
            t = ExeTrainStatus()
            t._deserialize(v)
            self._exe_status[k] = t

    def next(self):
        _thread_checker()

        if self._max_epoch_num < 0:
            self._max_epoch_num = sys.maxsize

        assert self._epoch_no >= -1, "self._epoch_no:{} must be >= -1".format(
            self._epoch_no)

        self._last_checkpoint_time = time.time()
        start = self._epoch_no + 1
        logger.info("started epoch_no:{} max_epoch_num:{}".format(
            start, self._max_epoch_num))

        for i in range(start, self._max_epoch_num):
            self._epoch_no = i
            yield i

            self.save_checkpoint()

    def get(self):
        return self._epoch_no

    def save_checkpoint(self):
        # Only trainer 0 saves, and never the last epoch, because the exe
        # and program of a finished run can't be restored.
        if self._checker.trainer_id == 0:

            if time.time() - self._last_checkpoint_time >= \
                    self._save_checkpoint_inter:
                if g_acp_type == CONST_ACP_TYPE:
                    # Don't save the last one.
                    if self._max_epoch_num > 0 and \
                            self._epoch_no != self._max_epoch_num - 1:
                        self._save_checkpoint()
                elif g_acp_type == CONST_DACP_TYPE:
                    self._save_checkpoint()
                else:
                    assert False, "not supported acp_type:{}".format(
                        g_acp_type)
                self._last_checkpoint_time = time.time()

    def _save_checkpoint(self):
        """
        status => <checkpoint_path>/<jobid>/range/<name>
        model  => <checkpoint_path>/<jobid>/exe/<name>
        """
        if not self._checker.valid():
            return

        e = self._exe_status
        for k, t in six.iteritems(self._exe_status):
            m = PaddleModel(t._exe, t._program)
            p = self._checker.get_exe_checkpoint_path(t._hash_key)
            t._epoch_no = self.get()
            path, checkpoint_no = self._cper.save_checkpoint(
                p, [m],
                self._checker.trainer_id,
                local_cache_path=self._checker._fs_cache)
            # index info
            t._checkpoint_path = path
            t._checkpoint_no = checkpoint_no

            e[t._key] = t

            logger.debug("save executor checkpoint:{}".format(t._serialize()))

        if len(self._exe_status) > 0:
            self._cper.save_checkpoint(
                self._checkpoint_path, [self],
                local_cache_path=self._checker._fs_cache)
            logger.info("save train_epoch_range checkpoint:{}".format(
                self._serialize()))

            self._generate_flag()

    def _generate_flag(self):
        if self._flag_generated:
            return

        name = "can_be_auto_checkpoint.flag"
        path = self._checker.get_job_path() + "/" + name
        logger.info("this job can_be_auto_checkpoint")
        self._hdfs.mkdirs(self._checker.get_job_path())
        self._hdfs.touch(path, exist_ok=True)

        self._flag_generated = True


def _get_train_epoch_range():
    return g_train_epoch_range


def _check_program_oprole(program):
    # A program is worth checkpointing only if it actually trains, i.e. it
    # contains both backward ops and optimizer ops.
    global_block = program.global_block()
    has_backward = False
    has_opt = False
    for op in global_block.ops:
        if op._is_backward_op():
            has_backward = True

        if op._is_optimize_op():
            has_opt = True

        if has_backward and has_opt:
            return True

    return False


def _can_auto_checkpoint(prog):
    if not isinstance(prog, compiler.CompiledProgram) and \
            not isinstance(prog, Program):
        return False

    if isinstance(prog, compiler.CompiledProgram):
        if prog._program is None or \
                prog._program._is_distributed:
            return False
    else:
        if prog._is_distributed:
            return False

    program = _get_valid_program(prog)

    if program._auto_checkpoint_name in g_program_attr:
        if not g_program_attr[program._auto_checkpoint_name]:
            return False
    else:
        ret = False
        if isinstance(program, compiler.CompiledProgram):
            ret = _check_program_oprole(program._program)
        else:
            ret = _check_program_oprole(program)

        g_program_attr[program._auto_checkpoint_name] = ret
        if not ret:
            logger.debug("program {} doesn't need auto checkpoint".format(
                program._auto_checkpoint_name))
            return False

    return g_checker.valid() and g_train_epoch_range is not None


def _get_running_key(exe_name, program_name):
    return "{}_{}".format(exe_name, program_name)


def _get_checker():
    _get_logger(20)
    global g_checker
    if g_checker is None:
        g_checker = AutoCheckpointChecker()

    return g_checker


def _normal_yield(max_epoch_num):
    if max_epoch_num < 0:
        max_epoch_num = sys.maxsize
    for i in range(0, max_epoch_num):
        yield i

    return


def train_epoch_range(max_epoch_num, save_checkpoint_inter=None):
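    """Iterate epoch indices [0, max_epoch_num), saving and restoring
    checkpoints automatically.

    Falls back to a plain range when the auto-checkpoint environment is
    not configured. ``max_epoch_num < 0`` means loop (practically) forever;
    ``save_checkpoint_inter`` overrides the save interval, in seconds,
    otherwise read from PADDLE_EDL_SAVE_CHECKPOINT_INTER.
    """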
    global g_acp_type
    if not _get_checker().valid():
        logger.warning(
            "auto checkpoint will only take effect automatically on PaddleCloud"
        )
        for i in _normal_yield(max_epoch_num):
            yield i

        return

    if g_acp_type == CONST_DACP_TYPE:
        for i in _normal_yield(max_epoch_num):
            yield i

        return

    g_acp_type = CONST_ACP_TYPE
    logger.info("acp_type:{}".format(g_acp_type))

    global g_train_epoch_range
    try:
        g_train_epoch_range = TrainEpochRange(
            max_epoch_num,
            g_checker.generate_range_name(),
            checkpoint_inter=save_checkpoint_inter)

        for i in g_train_epoch_range.next():
            yield i
    finally:
        g_train_epoch_range = None


def _get_valid_program(prog):
    if isinstance(prog, compiler.CompiledProgram):
        return prog._program

    return prog


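# Note: _auto_checkpoint below appears intended to be invoked by the executor
# on each run, with exe._auto_checkpoint_name / program._auto_checkpoint_name
# already assigned; this reading is inferred from the asserts that follow.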
def _auto_checkpoint(exe, prog):
    _get_checker()

    assert exe._auto_checkpoint_name is not None
    if not _can_auto_checkpoint(prog):
        return

    program = _get_valid_program(prog)
    assert program._auto_checkpoint_name is not None

    exe_status = g_train_epoch_range._exe_status
    key = _get_running_key(exe._auto_checkpoint_name,
                           program._auto_checkpoint_name)

    if g_train_epoch_range.restored_from == CONST_CHECKPOINT:
        assert key in exe_status, \
            "when restored, key:{} must be in train_epoch_range:{}".format(
                key, g_train_epoch_range)

    t = None
    if key in exe_status:
        t = exe_status[key]
        if t._restored_from is None:
            a = CheckpointSaver(g_train_epoch_range._hdfs)
            m = PaddleModel(exe, program)
            a.load_checkpoint(
                g_checker.get_exe_checkpoint_path(key), [m],
                trainer_id=g_checker.trainer_id,
                checkpoint_no=t._checkpoint_no,
                local_cache_path=g_checker._fs_cache)
            t._restored_from = CONST_CHECKPOINT
            logger.info("load executor checkpoint {}".format(t))
        t._exe = exe
        t._program = program
        t._epoch_no = g_train_epoch_range.get()
    else:
        t = ExeTrainStatus()
        t._epoch_no = g_train_epoch_range.get()
        t._hash_key = key
        t._key = key
        t._restored_from = CONST_MEMORYINIT
        t._exe = exe
        t._program = program
        t._exe_name = exe._auto_checkpoint_name
        t._program_name = program._auto_checkpoint_name

        # register this <exe, program, io>
        exe_status[key] = t

        logger.info("checkpoint not found, so train from epoch 0")

    _thread_checker()