|
|
|
@ -109,9 +109,14 @@ def init_model(config, program, exe):
|
|
|
|
|
"""
|
|
|
|
|
checkpoints = config['Global'].get('checkpoints')
|
|
|
|
|
if checkpoints:
|
|
|
|
|
path = checkpoints
|
|
|
|
|
fluid.load(program, path, exe)
|
|
|
|
|
logger.info("Finish initing model from {}".format(path))
|
|
|
|
|
if os.path.exists(checkpoints + '.pdparams'):
|
|
|
|
|
path = checkpoints
|
|
|
|
|
fluid.load(program, path, exe)
|
|
|
|
|
logger.info("Finish initing model from {}".format(path))
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Model checkpoints {} does not exists,"
|
|
|
|
|
"check if you lost the file prefix.".format(checkpoints + '.pdparams'))
|
|
|
|
|
|
|
|
|
|
pretrain_weights = config['Global'].get('pretrain_weights')
|
|
|
|
|
if pretrain_weights:
|
|
|
|
|