|
|
|
@ -133,14 +133,17 @@ def run_classifier():
|
|
|
|
|
"""run classifier task"""
|
|
|
|
|
parser = argparse.ArgumentParser(description="run classifier")
|
|
|
|
|
parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend")
|
|
|
|
|
parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: "
|
|
|
|
|
"[MCC, Spearman_correlation, "
|
|
|
|
|
"Accuracy], default is accuracy")
|
|
|
|
|
parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false")
|
|
|
|
|
parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false")
|
|
|
|
|
parser.add_argument("--assessment_method", type=str, default="accuracy",
|
|
|
|
|
help="assessment_method including [MCC, Spearman_correlation, Accuracy], default is accuracy")
|
|
|
|
|
parser.add_argument("--do_train", type=str, default="false", help="Enable train, default is false")
|
|
|
|
|
parser.add_argument("--do_eval", type=str, default="false", help="Enable eval, default is false")
|
|
|
|
|
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
|
|
|
|
parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
|
|
|
|
|
parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.")
|
|
|
|
|
parser.add_argument("--train_data_shuffle", type=str, default="true",
|
|
|
|
|
help="Enable train data shuffle, default is true")
|
|
|
|
|
parser.add_argument("--eval_data_shuffle", type=str, default="false",
|
|
|
|
|
help="Enable eval data shuffle, default is false")
|
|
|
|
|
parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
|
|
|
|
parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
|
|
|
|
parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
|
|
|
@ -182,7 +185,8 @@ def run_classifier():
|
|
|
|
|
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
|
|
|
|
|
assessment_method=assessment_method,
|
|
|
|
|
data_file_path=args_opt.train_data_file_path,
|
|
|
|
|
schema_file_path=args_opt.schema_file_path)
|
|
|
|
|
schema_file_path=args_opt.schema_file_path,
|
|
|
|
|
do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
|
|
|
|
|
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
|
|
|
|
|
|
|
|
|
|
if args_opt.do_eval.lower() == "true":
|
|
|
|
@ -197,7 +201,8 @@ def run_classifier():
|
|
|
|
|
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
|
|
|
|
|
assessment_method=assessment_method,
|
|
|
|
|
data_file_path=args_opt.eval_data_file_path,
|
|
|
|
|
schema_file_path=args_opt.schema_file_path)
|
|
|
|
|
schema_file_path=args_opt.schema_file_path,
|
|
|
|
|
do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
|
|
|
|
|
do_eval(ds, BertCLS, args_opt.num_class, assessment_method, load_finetune_checkpoint_path)
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|