From 4307c1fa61712af2fc44f3959edf7a947bef15c6 Mon Sep 17 00:00:00 2001
From: dessyang
Date: Wed, 29 Jul 2020 15:13:06 -0400
Subject: [PATCH] change the column order and add drop_remainder option to
 make this script compatible with BertCLS model

---
 .../clue_classification_dataset_process.py   | 40 +++++--------------
 1 file changed, 11 insertions(+), 29 deletions(-)

diff --git a/model_zoo/official/nlp/bert/src/clue_classification_dataset_process.py b/model_zoo/official/nlp/bert/src/clue_classification_dataset_process.py
index 1e27fe0352..042bc0a9c6 100755
--- a/model_zoo/official/nlp/bert/src/clue_classification_dataset_process.py
+++ b/model_zoo/official/nlp/bert/src/clue_classification_dataset_process.py
@@ -26,8 +26,8 @@ import mindspore.dataset.text as text
 import mindspore.dataset.transforms.c_transforms as ops
 
 
-def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path,
-                               data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64):
+def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False,
+                               max_seq_len=128, batch_size=64, drop_remainder=True):
     """Process TNEWS dataset"""
     ### Loading TNEWS from CLUEDataset
     assert data_usage in ['train', 'eval', 'test']
@@ -61,26 +61,17 @@ def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path,
     dataset = dataset.map(input_columns=["sentence"], output_columns=["text_ids"], operations=lookup)
     dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0))
     dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"],
-                          columns_order=["label_id", "text_ids", "mask_ids"], operations=ops.Duplicate())
+                          columns_order=["text_ids", "mask_ids", "label_id"], operations=ops.Duplicate())
     dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32))
     dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "segment_ids"],
-                          columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate())
+                          columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate())
     dataset = dataset.map(input_columns=["segment_ids"], operations=ops.Fill(0))
-    dataset = dataset.batch(batch_size)
-    label = []
-    text_ids = []
-    mask_ids = []
-    segment_ids = []
-    for data in dataset:
-        label.append(data[0])
-        text_ids.append(data[1])
-        mask_ids.append(data[2])
-        segment_ids.append(data[3])
-    return label, text_ids, mask_ids, segment_ids
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+    return dataset
 
 
-def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path,
-                               data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64):
+def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False,
+                               max_seq_len=128, batch_size=64, drop_remainder=True):
     """Process CMNLI dataset"""
     ### Loading CMNLI from CLUEDataset
     assert data_usage in ['train', 'eval', 'test']
@@ -138,16 +129,7 @@ def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path,
     dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0))
     ### Generating mask_ids
     dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"],
-                          columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate())
+                          columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate())
     dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32))
-    dataset = dataset.batch(batch_size)
-    label = []
-    text_ids = []
-    mask_ids = []
-    segment_ids = []
-    for data in dataset:
-        label.append(data[0])
-        text_ids.append(data[1])
-        mask_ids.append(data[2])
-        segment_ids.append(data[3])
-    return label, text_ids, mask_ids, segment_ids
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+    return dataset
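Usage note (illustrative, not part of the patch): after this change each
process_*_clue_dataset function returns the batched MindSpore dataset itself,
with columns ordered ["text_ids", "mask_ids", "segment_ids", "label_id"] as
the BertCLS model consumes them, instead of four pre-collected Python lists.
A minimal sketch; the import path, data/vocab paths, and label list below are
hypothetical placeholders:

    from src.clue_classification_dataset_process import process_tnews_clue_dataset

    tnews_labels = ["100", "101", "102"]          # hypothetical subset of TNEWS label ids

    dataset = process_tnews_clue_dataset(
        data_dir="/path/to/tnews/train.json",     # hypothetical path
        label_list=tnews_labels,
        bert_vocab_path="/path/to/vocab.txt",     # hypothetical path
        data_usage="train",
        batch_size=64,
        drop_remainder=True,  # drop the last partial batch so batch shapes stay static
    )

    # Each row now yields the four columns in the order BertCLS expects.
    for text_ids, mask_ids, segment_ids, label_id in dataset.create_tuple_iterator():
        print(text_ids.shape, mask_ids.shape, segment_ids.shape, label_id.shape)
        break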