|
|
|
@ -179,9 +179,9 @@ def train(config,
|
|
|
|
|
if 'start_epoch' in best_model_dict:
|
|
|
|
|
start_epoch = best_model_dict['start_epoch']
|
|
|
|
|
else:
|
|
|
|
|
start_epoch = 0
|
|
|
|
|
start_epoch = 1
|
|
|
|
|
|
|
|
|
|
for epoch in range(start_epoch, epoch_num):
|
|
|
|
|
for epoch in range(start_epoch, epoch_num + 1):
|
|
|
|
|
if epoch > 0:
|
|
|
|
|
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
|
|
|
|
train_batch_cost = 0.0
|
|
|
|
@ -216,7 +216,6 @@ def train(config,
|
|
|
|
|
stats['lr'] = lr
|
|
|
|
|
train_stats.update(stats)
|
|
|
|
|
|
|
|
|
|
#cal_metric_during_train = False
|
|
|
|
|
if cal_metric_during_train: # onlt rec and cls need
|
|
|
|
|
batch = [item.numpy() for item in batch]
|
|
|
|
|
post_result = post_process_class(preds, batch[1])
|
|
|
|
|