|
|
|
@@ -36,8 +36,8 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
|
|
|
|
ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
|
|
|
|
|
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
|
|
|
|
|
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
|
|
|
|
|
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
|
|
|
|
|
shard_equal_rows=True)
|
|
|
|
|
shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
|
|
|
|
|
num_shards=device_num, shard_id=rank, shard_equal_rows=True)
|
|
|
|
|
ori_dataset_size = ds.get_dataset_size()
|
|
|
|
|
print('origin dataset size: ', ori_dataset_size)
|
|
|
|
|
new_size = ori_dataset_size
|
|
|
|
|