From 2692e3cc3dafba1ece829b64a07ba0836e98c660 Mon Sep 17 00:00:00 2001 From: Zirui Wu Date: Thu, 11 Mar 2021 12:04:22 -0500 Subject: [PATCH] fix pk sampler's get_dataset error due to num_class unavaiable at pre-runtime --- .../datasetops/source/sampler/pk_sampler.h | 5 +++ .../datasetops/source/sampler/sampler.cc | 2 ++ .../datasetops/source/sampler/sampler.h | 2 +- .../source/sampler/sequential_sampler.cc | 2 ++ .../source/sampler/subset_sampler.cc | 2 ++ .../ir/datasetops/source/celeba_node.cc | 3 ++ .../ir/datasetops/source/cifar100_node.cc | 4 ++- .../ir/datasetops/source/cifar10_node.cc | 5 +++ .../engine/ir/datasetops/source/coco_node.cc | 3 ++ .../ir/datasetops/source/image_folder_node.cc | 3 ++ .../ir/datasetops/source/manifest_node.cc | 3 ++ .../engine/ir/datasetops/source/mnist_node.cc | 3 ++ .../engine/ir/datasetops/source/voc_node.cc | 3 ++ tests/ut/cpp/dataset/ir_sampler_test.cc | 4 +-- .../python/dataset/test_datasets_cifarop.py | 32 +++++++++++++++++++ 15 files changed, 72 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h index 3c97f5f0ad..e1163505d9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h @@ -66,6 +66,11 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief PK cannot return an exact value because num_classes is not known until runtime, hence -1 is used + /// \param[out] num_rows + /// \return -1, which means PKSampler doesn't know how much data + int64_t CalculateNumSamples(int64_t num_rows) override { return -1; } + private: bool shuffle_; uint32_t seed_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc index 485255e277..95d1bf4320 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc @@ -140,6 +140,8 @@ int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) { int64_t child_num_rows = num_rows; if (!child_.empty()) { child_num_rows = child_[0]->CalculateNumSamples(num_rows); + // return -1 if child_num_rows is undetermined + if (child_num_rows == -1) return child_num_rows; } return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h index 1f43fc9df4..3a51ffcfc4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h @@ -108,7 +108,7 @@ class SamplerRT { // Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of // num_samples_ - // @return number of samples + // @return number of samples, return -1 if sampler cannot determine this value (e.g. PKSampler) virtual int64_t CalculateNumSamples(int64_t num_rows); // setter for num or records in the dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index 70ff27dc05..5968cd29c4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -109,6 +109,8 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) { int64_t child_num_rows = num_rows; if (!child_.empty()) { child_num_rows = child_[0]->CalculateNumSamples(num_rows); + // return -1 if child_num_rows is undetermined + if (child_num_rows == -1) return child_num_rows; } int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; // For this sampler we need to take start_index into account. Because for example in the case we are given n rows diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc index 0d02c0c210..1fe703a8d8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc @@ -139,6 +139,8 @@ int64_t SubsetSamplerRT::CalculateNumSamples(int64_t num_rows) { int64_t child_num_rows = num_rows; if (!child_.empty()) { child_num_rows = child_[0]->CalculateNumSamples(num_rows); + // return -1 if child_num_rows is undetermined + if (child_num_rows == -1) return child_num_rows; } int64_t res = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; res = std::min(res, static_cast(indices_.size())); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index 7c3f62488e..8f451f41a5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -144,6 +144,9 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr &size std::shared_ptr sampler_rt = nullptr; RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } *dataset_size = sample_size; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index bb4866e076..991be365ce 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -95,7 +95,9 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr &si std::shared_ptr sampler_rt = nullptr; RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); sample_size = sampler_rt->CalculateNumSamples(num_rows); - + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index e53a91f171..3616a4a7b1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -88,12 +88,17 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr &siz *dataset_size = dataset_size_; return Status::OK(); } + int64_t num_rows, sample_size; RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows)); std::shared_ptr sampler_rt = nullptr; RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } + *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index eb36270758..715c15e240 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -151,6 +151,9 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr &size_g std::shared_ptr sampler_rt = nullptr; RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index fb3fc65efb..ad6d539a3d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -100,6 +100,9 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr std::shared_ptr sampler_rt = nullptr; RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index cd2d76dc91..2d33cc7567 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -123,6 +123,9 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr &si std::shared_ptr sampler_rt = nullptr; RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index 4b7c86ac38..c0e59d195f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -88,6 +88,9 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr &size_ std::shared_ptr sampler_rt = nullptr; RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index 2e5a455728..ab1a92c92a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -139,6 +139,9 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr &size_ge std::shared_ptr sampler_rt = nullptr; RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/tests/ut/cpp/dataset/ir_sampler_test.cc b/tests/ut/cpp/dataset/ir_sampler_test.cc index 42f2e4359e..9bb773c735 100644 --- a/tests/ut/cpp/dataset/ir_sampler_test.cc +++ b/tests/ut/cpp/dataset/ir_sampler_test.cc @@ -36,7 +36,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) { sampl = std::make_shared(3, false, 0); EXPECT_NE(sampl, nullptr); sampl->SamplerBuild(&sampler_rt); - EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30); + EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), -1); sampl = std::make_shared(false, 12); EXPECT_NE(sampl, nullptr); @@ -98,7 +98,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) { std::shared_ptr sampler_rt6; sampl6->SamplerBuild(&sampler_rt6); sampler_rt6->AddChild(sampler_rt5); - EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7); + EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), -1); } TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) { diff --git a/tests/ut/python/dataset/test_datasets_cifarop.py b/tests/ut/python/dataset/test_datasets_cifarop.py index 5a16f4ef84..65e77d3d59 100644 --- a/tests/ut/python/dataset/test_datasets_cifarop.py +++ b/tests/ut/python/dataset/test_datasets_cifarop.py @@ -501,6 +501,35 @@ def test_cifar_exception_file_path(): assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) +def test_cifar10_pk_sampler_get_dataset_size(): + """ + Test Cifar10Dataset with PKSampler and get_dataset_size + """ + sampler = ds.PKSampler(3) + data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler) + num_iter = 0 + ds_sz = data.get_dataset_size() + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_iter += 1 + + assert ds_sz == num_iter == 30 + + +def test_cifar10_with_chained_sampler_get_dataset_size(): + """ + Test Cifar10Dataset with PKSampler chained with a SequentialSampler and get_dataset_size + """ + sampler = ds.SequentialSampler(start_index=0, num_samples=5) + child_sampler = ds.PKSampler(4) + sampler.add_child(child_sampler) + data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler) + num_iter = 0 + ds_sz = data.get_dataset_size() + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_iter += 1 + assert ds_sz == num_iter == 5 + + if __name__ == '__main__': test_cifar10_content_check() test_cifar10_basic() @@ -517,3 +546,6 @@ if __name__ == '__main__': test_cifar_usage() test_cifar_exception_file_path() + + test_cifar10_with_chained_sampler_get_dataset_size() + test_cifar10_pk_sampler_get_dataset_size()