@ -92,8 +92,10 @@ def run_predistill():
dataset_size = dataset . get_dataset_size ( )
dataset_size = dataset . get_dataset_size ( )
if args_opt . enable_data_sink == ' true ' :
if args_opt . enable_data_sink == ' true ' :
repeat_count = args_opt . td_phase1_epoch_size * dataset . get_dataset_size ( ) / / args_opt . data_sink_steps
repeat_count = args_opt . td_phase1_epoch_size * dataset . get_dataset_size ( ) / / args_opt . data_sink_steps
time_monitor_steps = args_opt . data_sink_steps
else :
else :
repeat_count = args_opt . td_phase1_epoch_size
repeat_count = args_opt . td_phase1_epoch_size
time_monitor_steps = dataset_size
optimizer_cfg = cfg . optimizer_cfg
optimizer_cfg = cfg . optimizer_cfg
@ -110,10 +112,10 @@ def run_predistill():
{ ' order_params ' : params } ]
{ ' order_params ' : params } ]
optimizer = AdamWeightDecay ( group_params , learning_rate = lr_schedule , eps = optimizer_cfg . AdamWeightDecay . eps )
optimizer = AdamWeightDecay ( group_params , learning_rate = lr_schedule , eps = optimizer_cfg . AdamWeightDecay . eps )
callback = [ TimeMonitor ( dataset_size ) , LossCallBack ( ) , ModelSaveCkpt ( netwithloss . bert ,
callback = [ TimeMonitor ( time_monitor_steps ) , LossCallBack ( ) , ModelSaveCkpt ( netwithloss . bert ,
args_opt . save_ckpt_step ,
args_opt . save_ckpt_step ,
args_opt . max_ckpt_num ,
args_opt . max_ckpt_num ,
td_phase1_save_ckpt_dir ) ]
td_phase1_save_ckpt_dir ) ]
update_cell = DynamicLossScaleUpdateCell ( loss_scale_value = cfg . loss_scale_value ,
update_cell = DynamicLossScaleUpdateCell ( loss_scale_value = cfg . loss_scale_value ,
scale_factor = cfg . scale_factor ,
scale_factor = cfg . scale_factor ,
scale_window = cfg . scale_window )
scale_window = cfg . scale_window )
@ -147,8 +149,10 @@ def run_task_distill(ckpt_file):
dataset_size = train_dataset . get_dataset_size ( )
dataset_size = train_dataset . get_dataset_size ( )
if args_opt . enable_data_sink == ' true ' :
if args_opt . enable_data_sink == ' true ' :
repeat_count = args_opt . td_phase2_epoch_size * train_dataset . get_dataset_size ( ) / / args_opt . data_sink_steps
repeat_count = args_opt . td_phase2_epoch_size * train_dataset . get_dataset_size ( ) / / args_opt . data_sink_steps
time_monitor_steps = args_opt . data_sink_steps
else :
else :
repeat_count = args_opt . td_phase2_epoch_size
repeat_count = args_opt . td_phase2_epoch_size
time_monitor_steps = dataset_size
optimizer_cfg = cfg . optimizer_cfg
optimizer_cfg = cfg . optimizer_cfg
@ -170,14 +174,14 @@ def run_task_distill(ckpt_file):
device_num , rank , args_opt . do_shuffle ,
device_num , rank , args_opt . do_shuffle ,
args_opt . eval_data_dir , args_opt . schema_dir )
args_opt . eval_data_dir , args_opt . schema_dir )
if args_opt . do_eval . lower ( ) == " true " :
if args_opt . do_eval . lower ( ) == " true " :
callback = [ TimeMonitor ( dataset_size ) , LossCallBack ( ) ,
callback = [ TimeMonitor ( time_monitor_steps ) , LossCallBack ( ) ,
ModelSaveCkpt ( netwithloss . bert ,
ModelSaveCkpt ( netwithloss . bert ,
args_opt . save_ckpt_step ,
args_opt . save_ckpt_step ,
args_opt . max_ckpt_num ,
args_opt . max_ckpt_num ,
td_phase2_save_ckpt_dir ) ,
td_phase2_save_ckpt_dir ) ,
EvalCallBack ( netwithloss . bert , eval_dataset ) ]
EvalCallBack ( netwithloss . bert , eval_dataset ) ]
else :
else :
callback = [ TimeMonitor ( dataset_size ) , LossCallBack ( ) ,
callback = [ TimeMonitor ( time_monitor_steps ) , LossCallBack ( ) ,
ModelSaveCkpt ( netwithloss . bert ,
ModelSaveCkpt ( netwithloss . bert ,
args_opt . save_ckpt_step ,
args_opt . save_ckpt_step ,
args_opt . max_ckpt_num ,
args_opt . max_ckpt_num ,