From dcb758bc77b3df93cac4df76ff19f55c14aac79f Mon Sep 17 00:00:00 2001 From: luoyang Date: Fri, 6 Nov 2020 15:06:52 +0800 Subject: [PATCH] [MD] Skip returning IR when meet nullptr in ConcatDataset --- .../ccsrc/minddata/dataset/api/datasets.cc | 7 ++-- .../ut/cpp/dataset/c_api_dataset_ops_test.cc | 40 +++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 639e47e1ad..dd8422b3fe 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -478,9 +478,10 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset( ConcatDataset::ConcatDataset(const std::vector> &datasets) { std::vector> all_datasets; - (void)std::transform( - datasets.begin(), datasets.end(), std::back_inserter(all_datasets), - [](std::shared_ptr dataset) -> std::shared_ptr { return dataset->IRNode(); }); + (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets), + [](std::shared_ptr dataset) -> std::shared_ptr { + return (dataset != nullptr) ? dataset->IRNode() : nullptr; + }); auto ds = std::make_shared(all_datasets); diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc index ff74363e30..815d7df2b3 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc @@ -340,6 +340,46 @@ TEST_F(MindDataTestPipeline, TestConcatFail2) { EXPECT_EQ(iter, nullptr); } +TEST_F(MindDataTestPipeline, TestConcatFail3) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatFail3."; + // This case is expected to fail because the input dataset is nullptr. + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Concat operation on the ds + // Input dataset to concat is null + ds = ds->Concat({nullptr}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid Op input + EXPECT_EQ(iter, nullptr); +} + +TEST_F(MindDataTestPipeline, TestConcatFail4) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatFail4."; + // This case is expected to fail because the input dataset is nullptr. + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Concat operation on the ds + // Input dataset to concat is null + ds = ds + nullptr; + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid Op input + EXPECT_EQ(iter, nullptr); +} + TEST_F(MindDataTestPipeline, TestConcatSuccess) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatSuccess.";