|
|
|
@ -177,6 +177,8 @@ def train(config,
|
|
|
|
|
model_average = False
|
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
|
|
|
|
|
|
|
|
|
if 'start_epoch' in best_model_dict:
|
|
|
|
|
start_epoch = best_model_dict['start_epoch']
|
|
|
|
|
else:
|
|
|
|
@ -195,7 +197,7 @@ def train(config,
|
|
|
|
|
break
|
|
|
|
|
lr = optimizer.get_lr()
|
|
|
|
|
images = batch[0]
|
|
|
|
|
if config['Architecture']['algorithm'] == "SRN":
|
|
|
|
|
if use_srn:
|
|
|
|
|
others = batch[-4:]
|
|
|
|
|
preds = model(images, others)
|
|
|
|
|
model_average = True
|
|
|
|
@ -251,8 +253,12 @@ def train(config,
|
|
|
|
|
min_average_window=10000,
|
|
|
|
|
max_average_window=15625)
|
|
|
|
|
Model_Average.apply()
|
|
|
|
|
cur_metric = eval(model, valid_dataloader, post_process_class,
|
|
|
|
|
eval_class)
|
|
|
|
|
cur_metric = eval(
|
|
|
|
|
model,
|
|
|
|
|
valid_dataloader,
|
|
|
|
|
post_process_class,
|
|
|
|
|
eval_class,
|
|
|
|
|
use_srn=use_srn)
|
|
|
|
|
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
|
|
|
|
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
|
|
|
|
logger.info(cur_metric_str)
|
|
|
|
@ -316,7 +322,8 @@ def train(config,
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def eval(model, valid_dataloader, post_process_class, eval_class):
|
|
|
|
|
def eval(model, valid_dataloader, post_process_class, eval_class,
|
|
|
|
|
use_srn=False):
|
|
|
|
|
model.eval()
|
|
|
|
|
with paddle.no_grad():
|
|
|
|
|
total_frame = 0.0
|
|
|
|
@ -327,7 +334,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
|
|
|
|
break
|
|
|
|
|
images = batch[0]
|
|
|
|
|
start = time.time()
|
|
|
|
|
if "SRN" in str(model.head):
|
|
|
|
|
|
|
|
|
|
if use_srn:
|
|
|
|
|
others = batch[-4:]
|
|
|
|
|
preds = model(images, others)
|
|
|
|
|
else:
|
|
|
|
|