From 06b3d482525e96770489692ccb01b06da42bcb43 Mon Sep 17 00:00:00 2001 From: jonyguo Date: Mon, 15 Jun 2020 15:54:38 +0800 Subject: [PATCH] 1. add set_dataset_size for MindDataset 2. modify parameter dupe_factor from 5 to 10 --- example/nlp_to_mindrecord/zhwiki/run.sh | 2 +- .../nlp_to_mindrecord/zhwiki/run_simple.sh | 2 +- mindspore/dataset/engine/datasets.py | 36 ++++++++++++------- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/example/nlp_to_mindrecord/zhwiki/run.sh b/example/nlp_to_mindrecord/zhwiki/run.sh index 431ff54c65..a057031e6b 100644 --- a/example/nlp_to_mindrecord/zhwiki/run.sh +++ b/example/nlp_to_mindrecord/zhwiki/run.sh @@ -83,7 +83,7 @@ for index in $(seq 0 $file_list_len); do --max_predictions_per_seq=20 \ --masked_lm_prob=0.15 \ --random_seed=12345 \ - --dupe_factor=5 >/tmp/${output_filename[$index]}.log 2>&1 & + --dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 & # user defined process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l` echo "Total task: ${#file_list[*]}, processing: ${process_count}" if [ $process_count -ge $avaiable_core_size ]; then diff --git a/example/nlp_to_mindrecord/zhwiki/run_simple.sh b/example/nlp_to_mindrecord/zhwiki/run_simple.sh index 7574e851d0..20c1d98d66 100644 --- a/example/nlp_to_mindrecord/zhwiki/run_simple.sh +++ b/example/nlp_to_mindrecord/zhwiki/run_simple.sh @@ -44,4 +44,4 @@ python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched --max_predictions_per_seq=20 \ --masked_lm_prob=0.15 \ --random_seed=12345 \ ---dupe_factor=5 +--dupe_factor=10 # user defined diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index c51905208e..fe29738cb8 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2671,20 +2671,30 @@ class MindDataset(SourceDataset): Return: Number, number of batches. """ - if self.load_dataset: - dataset_file = [self.dataset_file] - else: - dataset_file = self.dataset_file - num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded) - if self.partitions is not None and self.partitions[0] > 0: - if num_rows % self.partitions[0] == 0: - num_rows = num_rows // self.partitions[0] + if self._dataset_size is None: + if self.load_dataset: + dataset_file = [self.dataset_file] else: - if self.num_padded > 0: - raise RuntimeError( - "Dataset size plus number of padded samples is not divisible by number of shards.") - num_rows = num_rows // self.partitions[0] + 1 - return num_rows + dataset_file = self.dataset_file + num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded) + if self.partitions is not None and self.partitions[0] > 0: + if num_rows % self.partitions[0] == 0: + num_rows = num_rows // self.partitions[0] + else: + if self.num_padded > 0: + raise RuntimeError( + "Dataset size plus number of padded samples is not divisible by number of shards.") + num_rows = num_rows // self.partitions[0] + 1 + return num_rows + return self._dataset_size + + # manually set dataset_size as a tempoary solution. + def set_dataset_size(self, value): + logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.") + if value >= 0: + self._dataset_size = value + else: + raise ValueError('set dataset_size with negative value {}'.format(value)) def is_shuffled(self): if self.shuffle_option is None: