|  |  | @ -133,14 +133,17 @@ def run_classifier(): | 
			
		
	
		
		
			
				
					
					|  |  |  |     """run classifier task""" |  |  |  |     """run classifier task""" | 
			
		
	
		
		
			
				
					
					|  |  |  |     parser = argparse.ArgumentParser(description="run classifier") |  |  |  |     parser = argparse.ArgumentParser(description="run classifier") | 
			
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") |  |  |  |     parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") | 
			
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: " |  |  |  |     parser.add_argument("--assessment_method", type=str, default="accuracy", | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                                                   "[MCC, Spearman_correlation, " |  |  |  |                         help="assessment_method including [MCC, Spearman_correlation, Accuracy], default is accuracy") | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                                                   "Accuracy], default is accuracy") |  |  |  |     parser.add_argument("--do_train", type=str, default="false", help="Enable train, default is false") | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") |  |  |  |     parser.add_argument("--do_eval", type=str, default="false", help="Enable eval, default is false") | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") |  |  |  |  | 
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") |  |  |  |     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | 
			
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") |  |  |  |     parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") | 
			
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") |  |  |  |     parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     parser.add_argument("--train_data_shuffle", type=str, default="true", | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         help="Enable train data shuffle, default is true") | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     parser.add_argument("--eval_data_shuffle", type=str, default="false", | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         help="Enable eval data shuffle, default is false") | 
			
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") |  |  |  |     parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") | 
			
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") |  |  |  |     parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") | 
			
		
	
		
		
			
				
					
					|  |  |  |     parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") |  |  |  |     parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") | 
			
		
	
	
		
		
			
				
					|  |  | @ -182,7 +185,8 @@ def run_classifier(): | 
			
		
	
		
		
			
				
					
					|  |  |  |         ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, |  |  |  |         ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                            assessment_method=assessment_method, |  |  |  |                                            assessment_method=assessment_method, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                            data_file_path=args_opt.train_data_file_path, |  |  |  |                                            data_file_path=args_opt.train_data_file_path, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                            schema_file_path=args_opt.schema_file_path) |  |  |  |                                            schema_file_path=args_opt.schema_file_path, | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                                            do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) | 
			
		
	
		
		
			
				
					
					|  |  |  |         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) |  |  |  |         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |         if args_opt.do_eval.lower() == "true": |  |  |  |         if args_opt.do_eval.lower() == "true": | 
			
		
	
	
		
		
			
				
					|  |  | @ -197,7 +201,8 @@ def run_classifier(): | 
			
		
	
		
		
			
				
					
					|  |  |  |         ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, |  |  |  |         ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                            assessment_method=assessment_method, |  |  |  |                                            assessment_method=assessment_method, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                            data_file_path=args_opt.eval_data_file_path, |  |  |  |                                            data_file_path=args_opt.eval_data_file_path, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                            schema_file_path=args_opt.schema_file_path) |  |  |  |                                            schema_file_path=args_opt.schema_file_path, | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                                            do_shuffle=(args_opt.eval_data_shuffle.lower() == "true")) | 
			
		
	
		
		
			
				
					
					|  |  |  |         do_eval(ds, BertCLS, args_opt.num_class, assessment_method, load_finetune_checkpoint_path) |  |  |  |         do_eval(ds, BertCLS, args_opt.num_class, assessment_method, load_finetune_checkpoint_path) | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | if __name__ == "__main__": |  |  |  | if __name__ == "__main__": | 
			
		
	
	
		
		
			
				
					|  |  | 
 |