@ -219,6 +219,13 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
epoch_num = config [ ' Global ' ] [ ' epoch_num ' ]
print_batch_step = config [ ' Global ' ] [ ' print_batch_step ' ]
eval_batch_step = config [ ' Global ' ] [ ' eval_batch_step ' ]
start_eval_step = 0
if type ( eval_batch_step ) == list and len ( eval_batch_step ) >= 2 :
start_eval_step = eval_batch_step [ 0 ]
eval_batch_step = eval_batch_step [ 1 ]
logger . info (
" During the training process, after the {} th iteration, an evaluation is run every {} iterations " .
format ( start_eval_step , eval_batch_step ) )
save_epoch_step = config [ ' Global ' ] [ ' save_epoch_step ' ]
save_model_dir = config [ ' Global ' ] [ ' save_model_dir ' ]
if not os . path . exists ( save_model_dir ) :
@ -246,7 +253,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
t2 = time . time ( )
train_batch_elapse = t2 - t1
train_stats . update ( stats )
if train_batch_id > 0 and train_batch_id \
if train_batch_id > start_eval_step and ( train_batch_id - start_eval_step ) \
% print_batch_step == 0 :
logs = train_stats . log ( )
strs = ' epoch: {} , iter: {} , {} , time: {:.3f} ' . format (
@ -286,6 +293,13 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
epoch_num = config [ ' Global ' ] [ ' epoch_num ' ]
print_batch_step = config [ ' Global ' ] [ ' print_batch_step ' ]
eval_batch_step = config [ ' Global ' ] [ ' eval_batch_step ' ]
start_eval_step = 0
if type ( eval_batch_step ) == list and len ( eval_batch_step ) >= 2 :
start_eval_step = eval_batch_step [ 0 ]
eval_batch_step = eval_batch_step [ 1 ]
logger . info (
" During the training process, after the {} th iteration, an evaluation is run every {} iterations " .
format ( start_eval_step , eval_batch_step ) )
save_epoch_step = config [ ' Global ' ] [ ' save_epoch_step ' ]
save_model_dir = config [ ' Global ' ] [ ' save_model_dir ' ]
if not os . path . exists ( save_model_dir ) :
@ -324,7 +338,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
train_batch_elapse = t2 - t1
stats = { ' loss ' : loss , ' acc ' : acc }
train_stats . update ( stats )
if train_batch_id > 0 and train_batch_id \
if train_batch_id > start_eval_step and ( train_batch_id - start_eval_step ) \
% print_batch_step == 0 :
logs = train_stats . log ( )
strs = ' epoch: {} , iter: {} , lr: {:.6f} , {} , time: {:.3f} ' . format (