|
|
@ -174,6 +174,7 @@ def train(config,
|
|
|
|
best_model_dict = {main_indicator: 0}
|
|
|
|
best_model_dict = {main_indicator: 0}
|
|
|
|
best_model_dict.update(pre_best_model_dict)
|
|
|
|
best_model_dict.update(pre_best_model_dict)
|
|
|
|
train_stats = TrainingStats(log_smooth_window, ['lr'])
|
|
|
|
train_stats = TrainingStats(log_smooth_window, ['lr'])
|
|
|
|
|
|
|
|
model_average = False
|
|
|
|
model.train()
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
|
|
|
if 'start_epoch' in best_model_dict:
|
|
|
|
if 'start_epoch' in best_model_dict:
|
|
|
@ -197,6 +198,7 @@ def train(config,
|
|
|
|
if config['Architecture']['algorithm'] == "SRN":
|
|
|
|
if config['Architecture']['algorithm'] == "SRN":
|
|
|
|
others = batch[-4:]
|
|
|
|
others = batch[-4:]
|
|
|
|
preds = model(images, others)
|
|
|
|
preds = model(images, others)
|
|
|
|
|
|
|
|
model_average = True
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
preds = model(images)
|
|
|
|
preds = model(images)
|
|
|
|
loss = loss_class(preds, batch)
|
|
|
|
loss = loss_class(preds, batch)
|
|
|
@ -242,12 +244,13 @@ def train(config,
|
|
|
|
# eval
|
|
|
|
# eval
|
|
|
|
if global_step > start_eval_step and \
|
|
|
|
if global_step > start_eval_step and \
|
|
|
|
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
|
|
|
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
|
|
|
model_average = paddle.optimizer.ModelAverage(
|
|
|
|
if model_average:
|
|
|
|
0.15,
|
|
|
|
Model_Average = paddle.incubate.optimizer.ModelAverage(
|
|
|
|
parameters=model.parameters(),
|
|
|
|
0.15,
|
|
|
|
min_average_window=10000,
|
|
|
|
parameters=model.parameters(),
|
|
|
|
max_average_window=15625)
|
|
|
|
min_average_window=10000,
|
|
|
|
model_average.apply()
|
|
|
|
max_average_window=15625)
|
|
|
|
|
|
|
|
Model_Average.apply()
|
|
|
|
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
|
|
|
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
|
|
|
eval_class)
|
|
|
|
eval_class)
|
|
|
|
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
|
|
|
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
|
|
|