diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 883b2051b4..8dfc2a142d 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -327,9 +327,7 @@ std::shared_ptr Dataset::BuildVocab(const std::vector &colum // Finish building vocab by triggering GetNextRow std::unordered_map> row; - iter->GetNextRow(&row); - if (vocab->vocab().empty()) { - MS_LOG(ERROR) << "Fail to build vocab."; + if (!iter->GetNextRow(&row)) { return nullptr; } @@ -1782,7 +1780,7 @@ bool BuildVocabDataset::ValidateParams() { MS_LOG(ERROR) << "BuildVocab: vocab is null."; return false; } - if (top_k_ < 0) { + if (top_k_ <= 0) { MS_LOG(ERROR) << "BuildVocab: top_k shoule be positive, but got: " << top_k_; return false; } diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc index d892135147..1559ca24c5 100644 --- a/mindspore/ccsrc/minddata/dataset/api/iterator.cc +++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc @@ -22,25 +22,29 @@ namespace dataset { namespace api { // Get the next row from the data pipeline. -void Iterator::GetNextRow(TensorMap *row) { +bool Iterator::GetNextRow(TensorMap *row) { Status rc = iterator_->GetNextAsMap(row); if (rc.IsError()) { MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; row->clear(); + return false; } + return true; } // Get the next row from the data pipeline. -void Iterator::GetNextRow(TensorVec *row) { +bool Iterator::GetNextRow(TensorVec *row) { TensorRow tensor_row; Status rc = iterator_->FetchNextTensorRow(&tensor_row); if (rc.IsError()) { MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; row->clear(); + return false; } // Generate a vector as return row->clear(); std::copy(tensor_row.begin(), tensor_row.end(), std::back_inserter(*row)); + return true; } // Shut down the data pipeline. 
diff --git a/mindspore/ccsrc/minddata/dataset/include/iterator.h b/mindspore/ccsrc/minddata/dataset/include/iterator.h index 443cca93eb..b6f476c54b 100644 --- a/mindspore/ccsrc/minddata/dataset/include/iterator.h +++ b/mindspore/ccsrc/minddata/dataset/include/iterator.h @@ -56,12 +56,14 @@ class Iterator { /// \brief Function to get the next row from the data pipeline. /// \note Type of return data is a map(with column name). /// \param[out] row - the output tensor row. - void GetNextRow(TensorMap *row); + /// \return Returns true if no error encountered else false. + bool GetNextRow(TensorMap *row); /// \brief Function to get the next row from the data pipeline. /// \note Type of return data is a vector(without column name). /// \param[out] row - the output tensor row. - void GetNextRow(TensorVec *row); + /// \return Returns true if no error encountered else false. + bool GetNextRow(TensorVec *row); /// \brief Function to shut down the data pipeline. void Stop(); diff --git a/mindspore/ccsrc/minddata/dataset/text/vocab.cc b/mindspore/ccsrc/minddata/dataset/text/vocab.cc index 35639d8b2d..719a4b3474 100644 --- a/mindspore/ccsrc/minddata/dataset/text/vocab.cc +++ b/mindspore/ccsrc/minddata/dataset/text/vocab.cc @@ -92,7 +92,7 @@ Status Vocab::BuildFromVector(const std::vector &words, const std::vec for (const WordType &word : words) { if (std::count(words.begin(), words.end(), word) > 1) { if (duplicate_word.find(word) == std::string::npos) { - duplicate_word = duplicate_word + ", " + word; + duplicate_word = duplicate_word.empty() ? 
duplicate_word + word : duplicate_word + ", " + word; } } } @@ -102,10 +102,16 @@ Status Vocab::BuildFromVector(const std::vector &words, const std::vec } std::string duplicate_sp; + std::string existed_sp; for (const WordType &sp : special_tokens) { if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) { if (duplicate_sp.find(sp) == std::string::npos) { - duplicate_sp = duplicate_sp + ", " + sp; + duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp; + } + } + if (std::count(words.begin(), words.end(), sp) >= 1) { + if (existed_sp.find(sp) == std::string::npos) { + existed_sp = existed_sp.empty() ? existed_sp + sp : existed_sp + ", " + sp; } } } @@ -113,6 +119,10 @@ Status Vocab::BuildFromVector(const std::vector &words, const std::vec MS_LOG(ERROR) << "special_tokens contains duplicate word: " << duplicate_sp; RETURN_STATUS_UNEXPECTED("special_tokens contains duplicate word: " + duplicate_sp); } + if (!existed_sp.empty()) { + MS_LOG(ERROR) << "special_tokens and word_list contain duplicate word: " << existed_sp; + RETURN_STATUS_UNEXPECTED("special_tokens and word_list contain duplicate word: " + existed_sp); + } std::unordered_map word2id; @@ -151,7 +161,7 @@ Status Vocab::BuildFromFileCpp(const std::string &path, const std::string &delim for (const WordType &sp : special_tokens) { if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) { if (duplicate_sp.find(sp) == std::string::npos) { - duplicate_sp = duplicate_sp + ", " + sp; + duplicate_sp = duplicate_sp.empty() ? 
duplicate_sp + sp : duplicate_sp + ", " + sp; } } } @@ -179,12 +189,12 @@ Status Vocab::BuildFromFileCpp(const std::string &path, const std::string &delim word = word.substr(0, word.find_first_of(delimiter)); } if (word2id.find(word) != word2id.end()) { - MS_LOG(ERROR) << "duplicate word:" + word + "."; - RETURN_STATUS_UNEXPECTED("duplicate word:" + word + "."); + MS_LOG(ERROR) << "word_list contains duplicate word:" + word + "."; + RETURN_STATUS_UNEXPECTED("word_list contains duplicate word:" + word + "."); } if (specials.find(word) != specials.end()) { - MS_LOG(ERROR) << word + " is already in special_tokens."; - RETURN_STATUS_UNEXPECTED(word + " is already in special_tokens."); + MS_LOG(ERROR) << "special_tokens and word_list contain duplicate word: " << word; + RETURN_STATUS_UNEXPECTED("special_tokens and word_list contain duplicate word: " + word); } word2id[word] = word_id++; // break if enough row is read, if vocab_size is smaller than 0 diff --git a/tests/ut/cpp/dataset/build_vocab_test.cc b/tests/ut/cpp/dataset/build_vocab_test.cc index 3edb4a8449..86f7a9a377 100644 --- a/tests/ut/cpp/dataset/build_vocab_test.cc +++ b/tests/ut/cpp/dataset/build_vocab_test.cc @@ -158,7 +158,7 @@ TEST_F(MindDataTestVocab, TestVocabFromEmptyVector) { TEST_F(MindDataTestVocab, TestVocabFromVectorFail1) { MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail1."; - // Build vocab from a vector of words with no special tokens + // Build vocab from a vector of words std::vector list = {"apple", "apple", "cat", "cat", "egg"}; std::vector sp_tokens = {}; std::shared_ptr vocab = std::make_shared(); @@ -170,7 +170,7 @@ TEST_F(MindDataTestVocab, TestVocabFromVectorFail1) { TEST_F(MindDataTestVocab, TestVocabFromVectorFail2) { MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail2."; - // Build vocab from a vector of words with no special tokens + // Build vocab from a vector std::vector list = {"apple", "dog", "egg"}; std::vector sp_tokens = {"", "", "", "", ""}; 
std::shared_ptr vocab = std::make_shared(); @@ -180,6 +180,18 @@ TEST_F(MindDataTestVocab, TestVocabFromVectorFail2) { EXPECT_NE(s, Status::OK()); } +TEST_F(MindDataTestVocab, TestVocabFromVectorFail3) { + MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail3."; + // Build vocab from a vector + std::vector list = {"apple", "dog", "egg", "", ""}; + std::vector sp_tokens = {"", ""}; + std::shared_ptr vocab = std::make_shared(); + + // Expected failure: special tokens already exist in word_list + Status s = Vocab::BuildFromVector(list, sp_tokens, true, &vocab); + EXPECT_NE(s, Status::OK()); +} + TEST_F(MindDataTestVocab, TestVocabFromFile) { MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFile."; // Build vocab from local file @@ -218,8 +230,8 @@ TEST_F(MindDataTestVocab, TestVocabFromFileFail2) { } TEST_F(MindDataTestVocab, TestVocabFromFileFail3) { - MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFileFail2."; - // Build vocab from local file which is not exist + MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFileFail3."; + // Build vocab from local file std::string vocab_dir = datasets_root_path_ + "/testVocab/vocab_list.txt"; std::shared_ptr vocab = std::make_shared(); @@ -227,3 +239,14 @@ Status s = Vocab::BuildFromFileCpp(vocab_dir, ",", -1, {"", ""}, true, &vocab); EXPECT_NE(s, Status::OK()); } + +TEST_F(MindDataTestVocab, TestVocabFromFileFail4) { + MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFileFail4."; + // Build vocab from local file + std::string vocab_dir = datasets_root_path_ + "/testVocab/vocab_list.txt"; + std::shared_ptr vocab = std::make_shared(); + + // Expected failure: special_tokens and word_list contain duplicate word + Status s = Vocab::BuildFromFileCpp(vocab_dir, ",", -1, {"home"}, true, &vocab); + EXPECT_NE(s, Status::OK()); +} diff --git a/tests/ut/cpp/dataset/c_api_dataset_vocab.cc b/tests/ut/cpp/dataset/c_api_dataset_vocab.cc index 
3d15b3cc9e..11926e5bb8 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_vocab.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_vocab.cc @@ -271,6 +271,21 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail3) { EXPECT_EQ(vocab, nullptr); } +TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail4) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetFail4."; + + // Create a TextFile dataset + std::string data_file = datasets_root_path_ + "/testVocab/words.txt"; + std::shared_ptr ds = TextFile({data_file}, 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create vocab from dataset + // Expected failure: special tokens are already in the dataset + std::shared_ptr vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits::max()}, + std::numeric_limits::max(), {"world"}); + EXPECT_EQ(vocab, nullptr); +} + TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetInt64."; @@ -318,4 +333,4 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) { iter->GetNextRow(&row); i++; } -} \ No newline at end of file +}