|
|
|
@ -184,7 +184,14 @@ def run_task_distill(ckpt_file):
|
|
|
|
|
if ckpt_file == '':
|
|
|
|
|
raise ValueError("Student ckpt file should not be None")
|
|
|
|
|
cfg = phase2_cfg
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
|
|
|
|
|
|
|
|
|
if args_opt.device_target == "Ascend":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
|
|
|
|
elif args_opt.device_target == "GPU":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("Target error, GPU or Ascend is supported.")
|
|
|
|
|
|
|
|
|
|
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
|
|
|
|
|
load_student_checkpoint_path = ckpt_file
|
|
|
|
|
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
|
|
|
|
|