日志符合benchmark规范

release/2.0-rc1-0
WenmuZhou 4 years ago
parent d4facfe4e5
commit e822901522

@ -185,12 +185,15 @@ def train(config,
for epoch in range(start_epoch, epoch_num):
if epoch > 0:
train_dataloader = build_dataloader(config, 'Train', device, logger)
train_batch_cost = 0.0
train_reader_cost = 0.0
batch_sum = 0
batch_start = time.time()
for idx, batch in enumerate(train_dataloader):
train_reader_cost += time.time() - batch_start
if idx >= len(train_dataloader):
break
lr = optimizer.get_lr()
t1 = time.time()
images = batch[0]
preds = model(images)
loss = loss_class(preds, batch)
@ -198,6 +201,10 @@ def train(config,
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
train_batch_cost += time.time() - batch_start
batch_sum += len(images)
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
@ -213,9 +220,6 @@ def train(config,
metirc = eval_class.get_metric()
train_stats.update(metirc)
t2 = time.time()
train_batch_elapse = t2 - t1
if vdl_writer is not None and dist.get_rank() == 0:
for k, v in train_stats.get().items():
vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
@ -224,9 +228,15 @@ def train(config,
if dist.get_rank(
) == 0 and global_step > 0 and global_step % print_batch_step == 0:
logs = train_stats.log()
strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format(
epoch, epoch_num, global_step, logs, train_batch_elapse)
strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f}s, batch_cost: {:.5f}s, samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost /
print_batch_step, train_batch_cost / print_batch_step,
batch_sum, batch_sum / train_batch_cost)
logger.info(strs)
train_batch_cost = 0.0
train_reader_cost = 0.0
batch_sum = 0
batch_start = time.time()
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:

Loading…
Cancel
Save