add switch for data shuffle

Branch: pull/3798/head
Author: yoonlee666, 5 years ago
Parent: 3119aaccfd
Commit: 9bdece71d4

@@ -133,14 +133,17 @@ def run_classifier():
     """run classifier task"""
     parser = argparse.ArgumentParser(description="run classifier")
     parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend")
-    parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: "
-                                                                                  "[MCC, Spearman_correlation, "
-                                                                                  "Accuracy], default is accuracy")
-    parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false")
-    parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false")
+    parser.add_argument("--assessment_method", type=str, default="accuracy",
+                        help="assessment_method including [MCC, Spearman_correlation, Accuracy], default is accuracy")
+    parser.add_argument("--do_train", type=str, default="false", help="Enable train, default is false")
+    parser.add_argument("--do_eval", type=str, default="false", help="Enable eval, default is false")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
     parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.")
+    parser.add_argument("--train_data_shuffle", type=str, default="true",
+                        help="Enable train data shuffle, default is true")
+    parser.add_argument("--eval_data_shuffle", type=str, default="false",
+                        help="Enable eval data shuffle, default is false")
     parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
     parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
@@ -182,7 +185,8 @@ def run_classifier():
         ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                            assessment_method=assessment_method,
                                            data_file_path=args_opt.train_data_file_path,
-                                           schema_file_path=args_opt.schema_file_path)
+                                           schema_file_path=args_opt.schema_file_path,
+                                           do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
     if args_opt.do_eval.lower() == "true":
@@ -197,7 +201,8 @@ def run_classifier():
         ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                            assessment_method=assessment_method,
                                            data_file_path=args_opt.eval_data_file_path,
-                                           schema_file_path=args_opt.schema_file_path)
+                                           schema_file_path=args_opt.schema_file_path,
+                                           do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
         do_eval(ds, BertCLS, args_opt.num_class, assessment_method, load_finetune_checkpoint_path)
 if __name__ == "__main__":

@@ -150,6 +150,10 @@ def run_ner():
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
     parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.")
+    parser.add_argument("--train_data_shuffle", type=str, default="true",
+                        help="Enable train data shuffle, default is true")
+    parser.add_argument("--eval_data_shuffle", type=str, default="false",
+                        help="Enable eval data shuffle, default is false")
     parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark")
     parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark")
     parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
@@ -208,7 +212,8 @@ def run_ner():
     if args_opt.do_train.lower() == "true":
         ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                 assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
-                                schema_file_path=args_opt.schema_file_path)
+                                schema_file_path=args_opt.schema_file_path,
+                                do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
     if args_opt.do_eval.lower() == "true":
@@ -222,7 +227,8 @@ def run_ner():
     if args_opt.do_eval.lower() == "true":
         ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                 assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
-                                schema_file_path=args_opt.schema_file_path)
+                                schema_file_path=args_opt.schema_file_path,
+                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
         do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
                 load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index)

@@ -140,6 +140,10 @@ def run_squad():
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
     parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.")
+    parser.add_argument("--train_data_shuffle", type=str, default="true",
+                        help="Enable train data shuffle, default is true")
+    parser.add_argument("--eval_data_shuffle", type=str, default="false",
+                        help="Enable eval data shuffle, default is false")
     parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path")
     parser.add_argument("--eval_json_path", type=str, default="", help="Evaluation json file path, can be eval.json")
     parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
@@ -186,7 +190,8 @@ def run_squad():
     if args_opt.do_train.lower() == "true":
         ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                   data_file_path=args_opt.train_data_file_path,
-                                  schema_file_path=args_opt.schema_file_path)
+                                  schema_file_path=args_opt.schema_file_path,
+                                  do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
     if args_opt.do_eval.lower() == "true":
         if save_finetune_checkpoint_path == "":
@@ -199,7 +204,8 @@ def run_squad():
     if args_opt.do_eval.lower() == "true":
         ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                   data_file_path=args_opt.eval_data_file_path,
-                                  schema_file_path=args_opt.schema_file_path, is_training=False)
+                                  schema_file_path=args_opt.schema_file_path, is_training=False,
+                                  do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
         do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path,
                 load_finetune_checkpoint_path, bert_net_cfg.seq_length)

@@ -34,6 +34,8 @@ python ${PROJECT_DIR}/../run_classifier.py \
     --device_id=0 \
     --epoch_num=1 \
     --num_class=2 \
+    --train_data_shuffle="true" \
+    --eval_data_shuffle="false" \
     --save_finetune_checkpoint_path="" \
     --load_pretrain_checkpoint_path="" \
     --load_finetune_checkpoint_path="" \

@@ -35,6 +35,8 @@ python ${PROJECT_DIR}/../run_ner.py \
     --device_id=0 \
     --epoch_num=1 \
     --num_class=2 \
+    --train_data_shuffle="true" \
+    --eval_data_shuffle="false" \
     --vocab_file_path="" \
     --label2id_file_path="" \
     --save_finetune_checkpoint_path="" \

@@ -33,6 +33,8 @@ python ${PROJECT_DIR}/../run_squad.py \
     --device_id=0 \
     --epoch_num=1 \
     --num_class=2 \
+    --train_data_shuffle="true" \
+    --eval_data_shuffle="false" \
     --vocab_file_path="" \
     --eval_json_path="" \
     --save_finetune_checkpoint_path="" \

@@ -34,7 +34,6 @@ class Accuracy():
         logit_id = np.argmax(logits, axis=-1)
         self.acc_num += np.sum(labels == logit_id)
         self.total_num += len(labels)
-        print("=========================accuracy is ", self.acc_num / self.total_num)
 class F1():
     '''
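The dropped line printed the running accuracy from the class's two counters on every update. For reference, a self-contained numpy sketch of the same arithmetic (toy values, not project code):

import numpy as np

logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # toy model outputs
labels = np.array([1, 0, 0])                             # toy ground-truth ids

logit_id = np.argmax(logits, axis=-1)  # predicted class per sample -> [1, 0, 1]
acc_num = np.sum(labels == logit_id)   # correct predictions -> 2
total_num = len(labels)                # samples seen -> 3
print("accuracy is ", acc_num / total_num)  # 0.666..., what the removed print reported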

@@ -53,11 +53,11 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
 def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
-                       data_file_path=None, schema_file_path=None):
+                       data_file_path=None, schema_file_path=None, do_shuffle=True):
     """create finetune or evaluation dataset"""
     type_cast_op = C.TypeCast(mstype.int32)
     ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
-                            columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"])
+                            columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], shuffle=do_shuffle)
     if assessment_method == "Spearman_correlation":
         type_cast_op_float = C.TypeCast(mstype.float32)
         ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
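With the new keyword in place, the finetune scripts thread the parsed flag straight through. A hedged usage sketch, assuming MindSpore is installed, the repo's src/dataset.py is importable as below, and a train.tf_record file exists (import path and file name are placeholders):

from src.dataset import create_ner_dataset  # repo-local import path (assumption)

train_shuffle_flag = "true"  # stand-in for args_opt.train_data_shuffle
ds = create_ner_dataset(batch_size=16, repeat_count=1,
                        assessment_method="accuracy",
                        data_file_path="train.tf_record",  # placeholder file
                        schema_file_path="",               # empty string -> no schema file
                        do_shuffle=(train_shuffle_flag.lower() == "true"))
print(ds.get_dataset_size())  # sanity check that the pipeline was built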
@@ -76,11 +76,11 @@ def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy
 def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
-                                  data_file_path=None, schema_file_path=None):
+                                  data_file_path=None, schema_file_path=None, do_shuffle=True):
     """create finetune or evaluation dataset"""
     type_cast_op = C.TypeCast(mstype.int32)
     ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
-                            columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"])
+                            columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], shuffle=do_shuffle)
     if assessment_method == "Spearman_correlation":
         type_cast_op_float = C.TypeCast(mstype.float32)
         ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
@@ -98,14 +98,15 @@ def create_classification_dataset(batch_size=1, repeat_count=1, assessment_metho
     return ds
-def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None, is_training=True):
+def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None,
+                         is_training=True, do_shuffle=True):
     """create finetune or evaluation dataset"""
     type_cast_op = C.TypeCast(mstype.int32)
     if is_training:
         ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
-                                columns_list=["input_ids", "input_mask", "segment_ids",
-                                              "start_positions", "end_positions",
-                                              "unique_ids", "is_impossible"])
+                                columns_list=["input_ids", "input_mask", "segment_ids", "start_positions",
+                                              "end_positions", "unique_ids", "is_impossible"],
+                                shuffle=do_shuffle)
         ds = ds.map(input_columns="start_positions", operations=type_cast_op)
         ds = ds.map(input_columns="end_positions", operations=type_cast_op)
     else:
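The net effect of the switch: MindSpore's TFRecordDataset accepts a bool for shuffle (alongside its Shuffle enum), so training keeps its global shuffle while evaluation can now preserve record order for reproducible metrics. A hedged sketch of the evaluation case, assuming MindSpore is installed and eval.tf_record (a placeholder name) exists; the columns shown are illustrative:

import mindspore.dataset as de

# Evaluation pipeline after this commit: shuffle comes from --eval_data_shuffle,
# which defaults to "false", so records are read in file order.
eval_ds = de.TFRecordDataset(["eval.tf_record"],  # placeholder file
                             columns_list=["input_ids", "input_mask", "segment_ids"],
                             shuffle=False)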
