diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc
index dbe4c032d2..2c2bfd88ee 100644
--- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc
@@ -402,7 +402,7 @@ std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::s
 // Function to overload "+" operator to concat two datasets
 std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
                                          const std::shared_ptr<Dataset> &datasets2) {
-  return std::make_shared<ConcatDataset>(std::vector<std::shared_ptr<Dataset>>({datasets2, datasets1}));
+  return std::make_shared<ConcatDataset>(std::vector<std::shared_ptr<Dataset>>({datasets1, datasets2}));
 }
 
 // Function to create a TextFileDataset.
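The ordering fix above is observable from the Python API as well. A minimal sketch (hypothetical generators gen_a/gen_b; assumes a MindSpore build that includes this patch) of what the corrected operator now guarantees, namely that the left operand's rows come first:

    import numpy as np
    import mindspore.dataset as ds

    def gen_a():  # 3 rows: 0, 1, 2
        for i in range(3):
            yield (np.array([i]),)

    def gen_b():  # 2 rows: 10, 11
        for i in range(10, 12):
            yield (np.array([i]),)

    data_a = ds.GeneratorDataset(gen_a, ["col1"], shuffle=False)
    data_b = ds.GeneratorDataset(gen_b, ["col1"], shuffle=False)

    # Before the fix, (data_a + data_b) yielded data_b's rows first; now order is left to right.
    values = [int(row["col1"][0])
              for row in (data_a + data_b).create_dict_iterator(num_epochs=1, output_numpy=True)]
    assert values == [0, 1, 2, 10, 11]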
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
index 569e752cbc..0209fd1ec7 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
@@ -73,6 +73,51 @@ Status ConcatNode::ValidateParams() {
   return Status::OK();
 }
 
+// Get Dataset size
+Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                                  int64_t *dataset_size) {
+  if (dataset_size_ > 0) {
+    *dataset_size = dataset_size_;
+    return Status::OK();
+  }
+
+  // calculate the total size of all nodes
+  int64_t total_dataset_size = 0;
+  int64_t child_dataset_size = 0;
+  for (int idx = 0; idx < children_.size(); idx++) {
+    if (children_flag_and_nums_.empty() || children_flag_and_nums_[idx].second == 0) {
+      children_[idx]->GetDatasetSize(size_getter, false, &child_dataset_size);
+      total_dataset_size += child_dataset_size;
+    } else {
+      total_dataset_size += children_flag_and_nums_[idx].second;
+    }
+  }
+
+  // calculate the size of the shard
+  int64_t shard_dataset_size = 0;
+  if (sampler_ != nullptr) {
+    std::shared_ptr<DistributedSamplerRT> sampler_rt =
+      std::static_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild());
+    sampler_rt->SetNumRowsInDataset(total_dataset_size);
+    sampler_rt->InitSampler();
+
+    // (total_size % num_shards != 0) && shard_id >= (remainder) ? CalculateNumSamples()-1 : CalculateNumSamples()
+    // example: 23 rows, 10 shards --> shard sizes = {3,3,3,2,2,2,2,2,2,2}
+    if ((sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum()) > 0 &&
+        sampler_rt->GetDeviceID() >= (sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum())) {
+      shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples()) - 1;
+    } else {
+      shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples());
+    }
+  } else {
+    shard_dataset_size = total_dataset_size;
+  }
+
+  *dataset_size = shard_dataset_size;
+  dataset_size_ = *dataset_size;
+  return Status::OK();
+}
+
 Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
     node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_));
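The remainder rule in that comment can be sanity-checked without MindSpore. A pure-Python sketch of the same arithmetic (names here are illustrative, not the C++ API):

    def shard_sizes(total_rows, num_shards):
        # Every shard gets floor(total/num_shards) rows; the first (total % num_shards)
        # shards get one extra, mirroring CalculateNumSamples() vs CalculateNumSamples()-1.
        base, remainder = divmod(total_rows, num_shards)
        return [base + 1 if shard_id < remainder else base for shard_id in range(num_shards)]

    # Matches the example in the code comment: 23 rows over 10 shards
    assert shard_sizes(23, 10) == [3, 3, 3, 2, 2, 2, 2, 2, 2, 2]
    assert sum(shard_sizes(23, 10)) == 23  # no row is dropped or duplicated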
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h
index e2dc960754..c20f30d4c1 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h
@@ -55,6 +55,15 @@ class ConcatNode : public DatasetNode {
   /// \return Status Status::OK() if build successfully
   Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
 
+  /// \brief Base-class override for GetDatasetSize
+  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
+  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
+  ///     dataset size at the expense of accuracy.
+  /// \param[out] dataset_size the size of the dataset
+  /// \return Status of the function
+  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                        int64_t *dataset_size) override;
+
   /// \brief Parameters validation
   /// \return Status Status::OK() if all the parameters are valid
   Status ValidateParams() override;
diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h
index 1789849e6f..da139f6a4f 100644
--- a/mindspore/ccsrc/minddata/dataset/include/datasets.h
+++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h
@@ -261,8 +261,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   /// \param[in] datasets List of shared pointers to the dataset that should be concatenated together
   /// \return Shared pointer to the current ConcatDataset
   std::shared_ptr<ConcatDataset> Concat(const std::vector<std::shared_ptr<Dataset>> &datasets) {
-    std::vector<std::shared_ptr<Dataset>> all_datasets = datasets;
-    all_datasets.push_back(shared_from_this());
+    std::vector<std::shared_ptr<Dataset>> all_datasets{shared_from_this()};
+    all_datasets.insert(std::end(all_datasets), std::begin(datasets), std::end(datasets));
     return std::make_shared<ConcatDataset>(all_datasets);
   }
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index bb7da16795..9d08e655fe 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -372,7 +372,7 @@ class Dataset:
         Args:
             condition_name (str): The condition name that is used to toggle sending next row.
             num_batch (int): the number of batches without blocking at the start of each epoch.
-            callback (function): The callback funciton that will be invoked when sync_update is called.
+            callback (function): The callback function that will be invoked when sync_update is called.
 
         Returns:
             SyncWaitDataset, dataset added a blocking condition.
@@ -398,7 +398,7 @@ class Dataset:
            1. Make a shuffle buffer that contains the first buffer_size rows.
            2. Randomly select an element from the shuffle buffer to be the next row
-              propogated to the child node.
+              propagated to the child node.
            3. Get the next row (if any) from the parent node and put it in the shuffle buffer.
            4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.
@@ -1649,8 +1649,7 @@ class MappableDataset(SourceDataset):
     def add_sampler(self, new_sampler):
         # note: By adding a sampler, the sampled IDs will flow to new_sampler
         # after first passing through the current samplers attached to this dataset.
-        if self.dataset_size is not None:
-            self.dataset_size = None
+        self.dataset_size = None
         new_sampler.add_child(self.sampler)
         self.sampler = new_sampler
@@ -1676,8 +1675,7 @@ class MappableDataset(SourceDataset):
             raise TypeError("Input sampler can not be None.")
         if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
             raise TypeError("Input sampler is not an instance of a sampler.")
-        if self.dataset_size is not None:
-            self.dataset_size = None
+        self.dataset_size = None
         self.sampler = self.sampler.child_sampler
         self.add_sampler(new_sampler)
@@ -1718,7 +1716,7 @@ class MappableDataset(SourceDataset):
                 - The sum of split sizes < K, the difference will be added to the first split.
 
                 - The sum of split sizes > K, the difference will be removed from the first large
-                  enough split such that it will have atleast 1 row after removing the difference.
+                  enough split such that it will have at least 1 row after removing the difference.
 
             randomize (bool, optional): Determines whether or not to split the data randomly (default=True).
                 If True, the data will be randomly split. Otherwise, each split will be created with
@@ -2647,6 +2645,8 @@ class ConcatDataset(Dataset):
             if sampler.get_num_samples() is not None:
                 raise ValueError("The parameter num_samples of DistributedSampler is not support to be set!")
 
+        self.dataset_size = None
+
         self._sampler = _select_sampler(None, sampler, None, None, None)
         cumulative_samples_nums = 0
         for index, child in enumerate(self.children):
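On the Python side, the cached size is now cleared unconditionally whenever a sampler is attached or replaced, so the next get_dataset_size() is recomputed under the new sampler instead of being served stale. A sketch of the behavior this guards (hypothetical 29-row list source; assumes MindSpore is installed):

    import numpy as np
    import mindspore.dataset as ds

    # A random-accessible source, so a sampler can be attached later
    data = ds.GeneratorDataset([(np.array([i]),) for i in range(29)], ["col1"], shuffle=False)
    assert data.get_dataset_size() == 29  # computed and cached

    # use_sampler() resets the cache; the size is recomputed for the shard
    data.use_sampler(ds.DistributedSampler(num_shards=10, shard_id=0, shuffle=False))
    assert data.get_dataset_size() == 3   # shard 0 of 29 rows over 10 shards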
diff --git a/tests/ut/cpp/dataset/tree_modifying_function_test.cc b/tests/ut/cpp/dataset/tree_modifying_function_test.cc
index a7b7075485..d440fcc504 100644
--- a/tests/ut/cpp/dataset/tree_modifying_function_test.cc
+++ b/tests/ut/cpp/dataset/tree_modifying_function_test.cc
@@ -53,7 +53,7 @@ TEST_F(MindDataTestTreeModifying, AppendChild) {
   std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds6 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds3 = ds1->Take(10);
-  std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3});  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
   Status rc;
   std::shared_ptr<DatasetNode> root = ds4->IRNode();
@@ -110,7 +110,7 @@ TEST_F(MindDataTestTreeModifying, InsertChildAt01) {
   std::shared_ptr<Dataset> ds3 = ds1->Take(10);
   std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds2 = ds5->Repeat(4);
-  std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3});  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
   Status rc;
   std::shared_ptr<DatasetNode> root = ds4->IRNode();
   auto ir_tree = std::make_shared<TreeAdapter>();
@@ -173,7 +173,7 @@ TEST_F(MindDataTestTreeModifying, InsertChildAt04) {
   std::shared_ptr<Dataset> ds3 = ds1->Take(10);
   std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds2 = ds5->Repeat(4);
-  std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3});  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
   Status rc;
   std::shared_ptr<DatasetNode> root = ds4->IRNode();
   auto ir_tree = std::make_shared<TreeAdapter>();
@@ -253,7 +253,7 @@ TEST_F(MindDataTestTreeModifying, InsertAbove01) {
   std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds3 = ds1->Take(10);
-  std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3});  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
   Status rc;
   std::shared_ptr<DatasetNode> root = ds4->IRNode();
@@ -280,7 +280,7 @@ TEST_F(MindDataTestTreeModifying, InsertAbove02) {
   std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds3 = ds1->Take(10);
-  std::shared_ptr<Dataset> ds4 = ds2 + ds3;  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3 + ds2;
   Status rc;
   std::shared_ptr<DatasetNode> root = ds4->IRNode();
@@ -307,7 +307,7 @@ TEST_F(MindDataTestTreeModifying, InsertAbove03) {
   std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds3 = ds1->Take(10);
-  std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3});  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
   Status rc;
   std::shared_ptr<DatasetNode> root = ds4->IRNode();
@@ -372,9 +372,9 @@ TEST_F(MindDataTestTreeModifying, Drop01) {
   std::shared_ptr<Dataset> ds9 = ds8->Skip(1);
   std::shared_ptr<Dataset> ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
-  std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3});  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
   std::shared_ptr<Dataset> ds6 = ds4->Take(13);
-  std::shared_ptr<Dataset> ds10 = ds6 + ds9;
+  std::shared_ptr<Dataset> ds10 = ds9 + ds6;
   Status rc;
   std::shared_ptr<DatasetNode> root = ds10->IRNode();
@@ -437,9 +437,9 @@ TEST_F(MindDataTestTreeModifying, Drop03) {
   std::shared_ptr<Dataset> ds9 = ds8->Skip(1);
   std::shared_ptr<Dataset> ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
-  std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3});  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
   std::shared_ptr<Dataset> ds6 = ds4->Take(13);
-  std::shared_ptr<Dataset> ds10 = ds6 + ds9;
+  std::shared_ptr<Dataset> ds10 = ds9 + ds6;
   Status rc;
   std::shared_ptr<DatasetNode> root = ds10->IRNode();
@@ -487,11 +487,11 @@ TEST_F(MindDataTestTreeModifying, Drop04) {
   std::shared_ptr<Dataset> ds9 = ds8->Skip(1);
   std::shared_ptr<Dataset> ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
-  std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3});  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
   std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
-  std::shared_ptr<Dataset> ds6 = ds1->Concat({ds5, ds4});  // ds1 is put after (ds5, ds4)!!!
-  std::shared_ptr<Dataset> ds10 = ds6 + ds9;
+  std::shared_ptr<Dataset> ds6 = ds5->Concat({ds4, ds1});
+  std::shared_ptr<Dataset> ds10 = ds9 + ds6;
   Status rc;
   std::shared_ptr<DatasetNode> root = ds10->IRNode();
@@ -548,8 +548,8 @@ TEST_F(MindDataTestTreeModifying, Drop05) {
   std::shared_ptr<Dataset> ds4 = ds3->Skip(1);
   std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
-  std::shared_ptr<Dataset> ds6 = ds1->Concat({ds5, ds4});  // ds1 is put after (ds5, ds4)!!!
-  std::shared_ptr<Dataset> ds10 = ds6 + ds9;
+  std::shared_ptr<Dataset> ds6 = ds5->Concat({ds4, ds1});
+  std::shared_ptr<Dataset> ds10 = ds9 + ds6;
   Status rc;
   std::shared_ptr<DatasetNode> root = ds10->IRNode();
@@ -603,11 +603,11 @@ TEST_F(MindDataTestTreeModifying, Drop06) {
   std::shared_ptr<Dataset> ds9 = ds8->Skip(1);
   std::shared_ptr<Dataset> ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
-  std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3});  // ds2 is the second child and ds3 is the first child!!!
+  std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
   std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
   std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
-  std::shared_ptr<Dataset> ds6 = ds1->Concat({ds5, ds4});  // ds1 is put after (ds5, ds4)!!!
-  std::shared_ptr<Dataset> ds10 = ds6 + ds9;
+  std::shared_ptr<Dataset> ds6 = ds5->Concat({ds4, ds1});
+  std::shared_ptr<Dataset> ds10 = ds9 + ds6;
   Status rc;
   std::shared_ptr<DatasetNode> root = ds10->IRNode();
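The test updates encode the same left-to-right rule for the multi-input form: after the fix, d->Concat({a, b}) produces children in the order (d, a, b). A Python sketch of the equivalent call (hypothetical make() helper; assumes Dataset.concat() accepts a list of datasets, as the released Python API does):

    import numpy as np
    import mindspore.dataset as ds

    def make(lo, hi):
        # Hypothetical helper: a tiny dataset yielding lo .. hi-1, one value per row
        return ds.GeneratorDataset(lambda: ((np.array([i]),) for i in range(lo, hi)),
                                   ["col1"], shuffle=False)

    a, b, c = make(0, 2), make(10, 12), make(20, 22)
    merged = a.concat([b, c])  # children in order (a, b, c), like a->Concat({b, c}) in C++

    values = [int(r["col1"][0])
              for r in merged.create_dict_iterator(num_epochs=1, output_numpy=True)]
    assert values == [0, 1, 10, 11, 20, 21]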
diff --git a/tests/ut/python/dataset/test_concat.py b/tests/ut/python/dataset/test_concat.py
index eff8c7905e..dd2c668965 100644
--- a/tests/ut/python/dataset/test_concat.py
+++ b/tests/ut/python/dataset/test_concat.py
@@ -33,12 +33,19 @@ def generator_10():
     for i in range(3, 10):
         yield (np.array([i]),)
 
+
 # In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19
 def generator_20():
     for i in range(10, 20):
         yield (np.array([i]),)
 
 
+# In generator_29 dataset: Number of rows is 9; its values are 20, 21, 22 ... 28
+def generator_29():
+    for i in range(20, 29):
+        yield (np.array([i]),)
+
+
 def test_concat_01():
     """
     Test concat: test concat 2 datasets that have the same column name and data type
@@ -316,7 +323,7 @@ def test_concat_13():
 
 def test_concat_14():
     """
-    Test concat: create dataset with different dataset folder, and do diffrent operation then concat
+    Test concat: concat two datasets created from different source folders, with different operations applied
     """
     logger.info("test_concat_14")
     DATA_DIR = "../data/dataset/testPK/data"
""" logger.info("test_concat_14") DATA_DIR = "../data/dataset/testPK/data" @@ -365,6 +372,63 @@ def test_concat_15(): assert sum([1 for _ in data3]) == 47 +def test_concat_16(): + """ + Test concat: test get_dataset_size on nested concats + """ + logger.info("test_concat_16") + DATA_DIR = "../data/dataset/testPK/data" + DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] + + data1 = ds.ImageFolderDataset(DATA_DIR) + data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"]) + + data3 = ds.GeneratorDataset(generator, ["col1"]) + data4 = ds.GeneratorDataset(generator_10, ["col1"]) + + data5 = data1 + data2 + data6 = data3 + data4 + data7 = data5 + data6 + + ds.config.set_seed(1) + + # 57 is the total size of all 4 leaf datasets + assert data7.get_dataset_size() == 57 + + +def test_concat_17(): + """ + Test concat: test get_dataset_size on nested concats (with sampler) + """ + logger.info("test_concat_17") + + data1 = ds.GeneratorDataset(generator, ["col1"]) + data2 = ds.GeneratorDataset(generator_10, ["col1"]) + + data3 = ds.GeneratorDataset(generator_20, ["col1"]) + data4 = ds.GeneratorDataset(generator_29, ["col1"]) + + data5 = data1 + data2 + data6 = data3 + data4 + data7 = data5 + data6 + + ds.config.set_seed(1) + shard_num = 10 + counter = 0 + + for i in range(shard_num): + distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None) + data7.use_sampler(distributed_sampler) + iter_counter = 0 + for _ in data7.create_dict_iterator(num_epochs=1, output_numpy=True): + counter += 1 + iter_counter += 1 + assert data7.get_dataset_size() == iter_counter + + # 29 is the total size of all 4 leaf datasets + assert counter == 29 + + if __name__ == "__main__": test_concat_01() test_concat_02() @@ -381,3 +445,5 @@ if __name__ == "__main__": test_concat_13() test_concat_14() test_concat_15() + test_concat_16() + test_concat_17()