@@ -185,12 +185,15 @@ def train(config,
     for epoch in range(start_epoch, epoch_num):
         if epoch > 0:
             train_dataloader = build_dataloader(config, 'Train', device, logger)
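+        # throughput counters, reset at each epoch and after every log line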
+        train_batch_cost = 0.0
+        train_reader_cost = 0.0
+        batch_sum = 0
+        batch_start = time.time()
         for idx, batch in enumerate(train_dataloader):
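+            # time spent blocked on the dataloader for this batch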
+            train_reader_cost += time.time() - batch_start
             if idx >= len(train_dataloader):
                 break
             lr = optimizer.get_lr()
             t1 = time.time()
             images = batch[0]
             preds = model(images)
             loss = loss_class(preds, batch)
@@ -198,6 +201,10 @@ def train(config,
             avg_loss.backward()
             optimizer.step()
             optimizer.clear_grad()
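+            # total wall-clock cost of the batch, dataloader wait included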
+            train_batch_cost += time.time() - batch_start
+            batch_sum += len(images)
             if not isinstance(lr_scheduler, float):
                 lr_scheduler.step()
@@ -213,9 +220,6 @@ def train(config,
             metric = eval_class.get_metric()
             train_stats.update(metric)
-            t2 = time.time()
-            train_batch_elapse = t2 - t1
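             # rank-0 only: mirror the running training stats to VisualDL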
             if vdl_writer is not None and dist.get_rank() == 0:
                 for k, v in train_stats.get().items():
                     vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
@@ -224,9 +228,15 @@ def train(config,
             if dist.get_rank(
             ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
                 logs = train_stats.log()
-                strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format(
-                    epoch, epoch_num, global_step, logs, train_batch_elapse)
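+                # reader/batch cost are averaged over the print interval;
+                # ips (samples per second) = batch_sum / train_batch_cost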
+                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f}s, batch_cost: {:.5f}s, samples: {}, ips: {:.5f}'.format(
+                    epoch, epoch_num, global_step, logs, train_reader_cost /
+                    print_batch_step, train_batch_cost / print_batch_step,
+                    batch_sum, batch_sum / train_batch_cost)
                 logger.info(strs)
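+                # reset the counters so the next interval starts clean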
+                train_batch_cost = 0.0
+                train_reader_cost = 0.0
+                batch_sum = 0
+                batch_start = time.time()
             # eval
             if global_step > start_eval_step and \
                     (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: