|
|
|
@ -332,7 +332,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
|
|
|
|
return metirc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess():
|
|
|
|
|
def preprocess(is_train=False):
|
|
|
|
|
FLAGS = ArgsParser().parse_args()
|
|
|
|
|
config = load_config(FLAGS.config)
|
|
|
|
|
merge_config(FLAGS.opt)
|
|
|
|
@ -350,15 +350,17 @@ def preprocess():
|
|
|
|
|
device = paddle.set_device(device)
|
|
|
|
|
|
|
|
|
|
config['Global']['distributed'] = dist.get_world_size() != 1
|
|
|
|
|
|
|
|
|
|
if is_train:
|
|
|
|
|
# save_config
|
|
|
|
|
save_model_dir = config['Global']['save_model_dir']
|
|
|
|
|
os.makedirs(save_model_dir, exist_ok=True)
|
|
|
|
|
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
|
|
|
|
|
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
|
|
|
|
|
|
|
|
|
|
logger = get_logger(
|
|
|
|
|
name='root', log_file='{}/train.log'.format(save_model_dir))
|
|
|
|
|
yaml.dump(
|
|
|
|
|
dict(config), f, default_flow_style=False, sort_keys=False)
|
|
|
|
|
log_file = '{}/train.log'.format(save_model_dir)
|
|
|
|
|
else:
|
|
|
|
|
log_file = None
|
|
|
|
|
logger = get_logger(name='root', log_file=log_file)
|
|
|
|
|
if config['Global']['use_visualdl']:
|
|
|
|
|
from visualdl import LogWriter
|
|
|
|
|
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
|
|
|
|