change the column order and add drop_remainder option to make this script compatible with the BertCLS model

pull/3693/head
dessyang 5 years ago
parent 6f70146153
commit 4307c1fa61
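
The reorder and the new drop_remainder flag both serve positional binding: MindSpore feeds dataset columns to the network's construct() in column order, and dropping the final partial batch keeps every batch at a fixed shape for graph compilation. A minimal sketch of the intended alignment, assuming a BertCLS-style signature (the class below is illustrative, not part of this commit):

# Illustrative stand-in for BertCLS; the signature is an assumption
# based on typical BERT fine-tune networks, not taken from this diff.
import mindspore.nn as nn

class BertCLSDemo(nn.Cell):
    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        # Columns map to inputs by position under the new order:
        #   text_ids -> input_ids, mask_ids -> input_mask,
        #   segment_ids -> token_type_id, label_id -> label_ids
        ...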

@@ -26,8 +26,8 @@ import mindspore.dataset.text as text
 import mindspore.dataset.transforms.c_transforms as ops


-def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path,
-                               data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64):
+def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False,
+                               max_seq_len=128, batch_size=64, drop_remainder=True):
     """Process TNEWS dataset"""
     ### Loading TNEWS from CLUEDataset
     assert data_usage in ['train', 'eval', 'test']
@@ -61,26 +61,17 @@ def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path,
     dataset = dataset.map(input_columns=["sentence"], output_columns=["text_ids"], operations=lookup)
     dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0))
     dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"],
-                          columns_order=["label_id", "text_ids", "mask_ids"], operations=ops.Duplicate())
+                          columns_order=["text_ids", "mask_ids", "label_id"], operations=ops.Duplicate())
     dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32))
     dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "segment_ids"],
-                          columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate())
+                          columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate())
     dataset = dataset.map(input_columns=["segment_ids"], operations=ops.Fill(0))
-    dataset = dataset.batch(batch_size)
-    label = []
-    text_ids = []
-    mask_ids = []
-    segment_ids = []
-    for data in dataset:
-        label.append(data[0])
-        text_ids.append(data[1])
-        mask_ids.append(data[2])
-        segment_ids.append(data[3])
-    return label, text_ids, mask_ids, segment_ids
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+    return dataset


-def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path,
-                               data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64):
+def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False,
+                               max_seq_len=128, batch_size=64, drop_remainder=True):
     """Process CMNLI dataset"""
     ### Loading CMNLI from CLUEDataset
     assert data_usage in ['train', 'eval', 'test']
@@ -138,16 +129,7 @@ def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path,
     dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0))
     ### Generating mask_ids
     dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"],
-                          columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate())
+                          columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate())
     dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32))
-    dataset = dataset.batch(batch_size)
-    label = []
-    text_ids = []
-    mask_ids = []
-    segment_ids = []
-    for data in dataset:
-        label.append(data[0])
-        text_ids.append(data[1])
-        mask_ids.append(data[2])
-        segment_ids.append(data[3])
-    return label, text_ids, mask_ids, segment_ids
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+    return dataset
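
Since both functions now return the Dataset itself rather than unpacked Python lists, the result can go straight into training. A hedged usage sketch; the paths, label list, and loss-wrapped network below are placeholders, not part of this commit:

# Hypothetical call site for the refactored API.
from mindspore.train.model import Model

label_list = ["0", "1"]  # placeholder labels, not the real TNEWS label set
ds_train = process_tnews_clue_dataset("/path/to/TNEWS", label_list, "/path/to/vocab.txt",
                                      data_usage='train', shuffle_dataset=True,
                                      max_seq_len=128, batch_size=64, drop_remainder=True)
model = Model(bert_cls_with_loss)  # bert_cls_with_loss: a BertCLS wrapped with loss, defined elsewhere
model.train(epoch=3, train_dataset=ds_train, dataset_sink_mode=True)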
