Change return value of GetNextRow

pull/5861/head
luoyang 4 years ago
parent f16ad7aa27
commit 581335453e

@@ -327,9 +327,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
// Finish building vocab by triggering GetNextRow
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
if (vocab->vocab().empty()) {
MS_LOG(ERROR) << "Fail to build vocab.";
if (!iter->GetNextRow(&row)) {
return nullptr;
}
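With this change BuildVocab no longer has to peek at the vocab contents to detect a failed pipeline; it checks the bool returned by GetNextRow directly. A minimal caller sketch under the new contract (illustrative only, not part of this commit; assumes iter is a valid iterator obtained from the dataset as in the snippet above):
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  if (!iter->GetNextRow(&row)) {
    return nullptr;  // pipeline reported an error
  }
  while (!row.empty()) {
    // consume the row ...
    if (!iter->GetNextRow(&row)) {
      return nullptr;  // error in the middle of iteration
    }
  }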
@@ -1782,7 +1780,7 @@ bool BuildVocabDataset::ValidateParams() {
MS_LOG(ERROR) << "BuildVocab: vocab is null.";
return false;
}
if (top_k_ < 0) {
if (top_k_ <= 0) {
MS_LOG(ERROR) << "BuildVocab: top_k shoule be positive, but got: " << top_k_;
return false;
}
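The check is tightened from top_k < 0 to top_k <= 0, so a top_k of 0 is now rejected during validation instead of silently producing an empty vocab later. For example, a call with the same shape as the pipeline tests further down would now fail up front (sketch; values are illustrative):
  std::shared_ptr<Vocab> vocab =
      ds->BuildVocab({"text"}, {1, std::numeric_limits<int64_t>::max()}, 0, {"<pad>"});
  // vocab == nullptr: ValidateParams logs "top_k should be positive" and BuildVocab bails out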

@@ -22,25 +22,29 @@ namespace dataset {
namespace api {
// Get the next row from the data pipeline.
void Iterator::GetNextRow(TensorMap *row) {
bool Iterator::GetNextRow(TensorMap *row) {
Status rc = iterator_->GetNextAsMap(row);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
row->clear();
return false;
}
return true;
}
// Get the next row from the data pipeline.
void Iterator::GetNextRow(TensorVec *row) {
bool Iterator::GetNextRow(TensorVec *row) {
TensorRow tensor_row;
Status rc = iterator_->FetchNextTensorRow(&tensor_row);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
row->clear();
return false;
}
// Generate a vector as return
row->clear();
std::copy(tensor_row.begin(), tensor_row.end(), std::back_inserter(*row));
return true;
}
// Shut down the data pipeline.

@@ -56,12 +56,14 @@ class Iterator {
/// \brief Function to get the next row from the data pipeline.
/// \note Type of return data is a map(with column name).
/// \param[out] row - the output tensor row.
void GetNextRow(TensorMap *row);
/// \return Returns true if no error encountered else false.
bool GetNextRow(TensorMap *row);
/// \brief Function to get the next row from the data pipeline.
/// \note Type of return data is a vector(without column name).
/// \param[out] row - the output tensor row.
void GetNextRow(TensorVec *row);
/// \return Returns true if no error encountered else false.
bool GetNextRow(TensorVec *row);
/// \brief Function to shut down the data pipeline.
void Stop();
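Both overloads now report success through their return value, so a consumer loop can stop on a broken pipeline without scanning the logs. A usage sketch for the vector overload (hypothetical setup; assumes an iterator obtained from an existing dataset, e.g. via CreateIterator()):
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  TensorVec row;
  if (!iter->GetNextRow(&row)) {
    // handle the error; row has been cleared
  }
  while (!row.empty()) {
    // tensors arrive in column order, without names
    if (!iter->GetNextRow(&row)) {
      break;
    }
  }
  iter->Stop();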

@@ -92,7 +92,7 @@ Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vec
for (const WordType &word : words) {
if (std::count(words.begin(), words.end(), word) > 1) {
if (duplicate_word.find(word) == std::string::npos) {
duplicate_word = duplicate_word + ", " + word;
duplicate_word = duplicate_word.empty() ? duplicate_word + word : duplicate_word + ", " + word;
}
}
}
@@ -102,10 +102,16 @@ Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vec
}
std::string duplicate_sp;
std::string existed_sp;
for (const WordType &sp : special_tokens) {
if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) {
if (duplicate_sp.find(sp) == std::string::npos) {
duplicate_sp = duplicate_sp + ", " + sp;
duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp;
}
}
if (std::count(words.begin(), words.end(), sp) >= 1) {
if (existed_sp.find(sp) == std::string::npos) {
existed_sp = existed_sp.empty() ? existed_sp + sp : existed_sp + ", " + sp;
}
}
}
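The ternary that keeps appearing here is just a separator-free join: the message starts empty, the first duplicate is appended bare, and every later one is prefixed with ", ". A standalone sketch of the idiom:
  std::vector<std::string> hits = {"<pad>", "<unk>"};
  std::string joined;
  for (const std::string &w : hits) {
    joined = joined.empty() ? w : joined + ", " + w;
  }
  // joined == "<pad>, <unk>" -- no stray leading ", " as the old concatenation produced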
@@ -113,6 +119,10 @@ Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vec
MS_LOG(ERROR) << "special_tokens contains duplicate word: " << duplicate_sp;
RETURN_STATUS_UNEXPECTED("special_tokens contains duplicate word: " + duplicate_sp);
}
if (!existed_sp.empty()) {
MS_LOG(ERROR) << "special_tokens and word_list contain duplicate word: " << existed_sp;
RETURN_STATUS_UNEXPECTED("special_tokens and word_list contain duplicate word: " + existed_sp);
}
std::unordered_map<WordType, WordIdType> word2id;
@@ -151,7 +161,7 @@ Status Vocab::BuildFromFileCpp(const std::string &path, const std::string &delim
for (const WordType &sp : special_tokens) {
if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) {
if (duplicate_sp.find(sp) == std::string::npos) {
duplicate_sp = duplicate_sp + ", " + sp;
duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp;
}
}
}
@@ -179,12 +189,12 @@ Status Vocab::BuildFromFileCpp(const std::string &path, const std::string &delim
word = word.substr(0, word.find_first_of(delimiter));
}
if (word2id.find(word) != word2id.end()) {
MS_LOG(ERROR) << "duplicate word:" + word + ".";
RETURN_STATUS_UNEXPECTED("duplicate word:" + word + ".");
MS_LOG(ERROR) << "word_list contains duplicate word:" + word + ".";
RETURN_STATUS_UNEXPECTED("word_list contains duplicate word:" + word + ".");
}
if (specials.find(word) != specials.end()) {
MS_LOG(ERROR) << word + " is already in special_tokens.";
RETURN_STATUS_UNEXPECTED(word + " is already in special_tokens.");
MS_LOG(ERROR) << "special_tokens and word_list contain duplicate word: " << word;
RETURN_STATUS_UNEXPECTED("special_tokens and word_list contain duplicate word: " + word);
}
word2id[word] = word_id++;
// break if enough row is read, if vocab_size is smaller than 0
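For context: BuildFromFileCpp reads one entry per line and, when a delimiter is supplied, keeps only the text before it, so the rest of the line is discarded before the duplicate checks above run. A hypothetical line with delimiter "," (the real testVocab/vocab_list.txt contents are not shown in this diff):
  std::string word = "home,5";                     // hypothetical file line
  word = word.substr(0, word.find_first_of(","));  // -> "home"
  // "home" would then trip the new special_tokens/word_list duplicate check when {"home"}
  // is passed as a special token, which is what TestVocabFromFileFail4 below expects.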

@@ -158,7 +158,7 @@ TEST_F(MindDataTestVocab, TestVocabFromEmptyVector) {
TEST_F(MindDataTestVocab, TestVocabFromVectorFail1) {
MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail1.";
// Build vocab from a vector of words with no special tokens
// Build vocab from a vector of words
std::vector<std::string> list = {"apple", "apple", "cat", "cat", "egg"};
std::vector<std::string> sp_tokens = {};
std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
@@ -170,7 +170,7 @@ TEST_F(MindDataTestVocab, TestVocabFromVectorFail1) {
TEST_F(MindDataTestVocab, TestVocabFromVectorFail2) {
MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail2.";
// Build vocab from a vector of words with no special tokens
// Build vocab from a vector
std::vector<std::string> list = {"apple", "dog", "egg"};
std::vector<std::string> sp_tokens = {"<pad>", "<unk>", "<pad>", "<unk>", "<none>"};
std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
@@ -180,6 +180,18 @@ TEST_F(MindDataTestVocab, TestVocabFromVectorFail2) {
EXPECT_NE(s, Status::OK());
}
TEST_F(MindDataTestVocab, TestVocabFromVectorFail3) {
MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail3.";
// Build vocab from a vector
std::vector<std::string> list = {"apple", "dog", "egg", "<unk>", "<pad>"};
std::vector<std::string> sp_tokens = {"<pad>", "<unk>"};
std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
// Expected failure: special tokens already exist in word_list
Status s = Vocab::BuildFromVector(list, sp_tokens, true, &vocab);
EXPECT_NE(s, Status::OK());
}
TEST_F(MindDataTestVocab, TestVocabFromFile) {
MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFile.";
// Build vocab from local file
@@ -218,8 +230,8 @@ TEST_F(MindDataTestVocab, TestVocabFromFileFail2) {
}
TEST_F(MindDataTestVocab, TestVocabFromFileFail3) {
MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFileFail2.";
// Build vocab from local file which is not exist
MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFileFail3.";
// Build vocab from local file
std::string vocab_dir = datasets_root_path_ + "/testVocab/vocab_list.txt";
std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
@@ -227,3 +239,14 @@ TEST_F(MindDataTestVocab, TestVocabFromFileFail3) {
Status s = Vocab::BuildFromFileCpp(vocab_dir, ",", -1, {"<unk>", "<unk>"}, true, &vocab);
EXPECT_NE(s, Status::OK());
}
TEST_F(MindDataTestVocab, TestVocabFromFileFail4) {
MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFileFail4.";
// Build vocab from local file
std::string vocab_dir = datasets_root_path_ + "/testVocab/vocab_list.txt";
std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
// Expected failure: special_tokens and word_list contain duplicate word
Status s = Vocab::BuildFromFileCpp(vocab_dir, ",", -1, {"home"}, true, &vocab);
EXPECT_NE(s, Status::OK());
}

@@ -271,6 +271,21 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail3) {
EXPECT_EQ(vocab, nullptr);
}
TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail4) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetFail4.";
// Create a TextFile dataset
std::string data_file = datasets_root_path_ + "/testVocab/words.txt";
std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create vocab from dataset
// Expected failure: special tokens are already in the dataset
std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()},
std::numeric_limits<int64_t>::max(), {"world"});
EXPECT_EQ(vocab, nullptr);
}
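For contrast with the failure cases, a successful build still only needs the nullptr check; with this commit, a GetNextRow failure inside BuildVocab is simply one more path that yields a null vocab (sketch; column name and limits mirror the tests above, other values are illustrative):
  std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {1, std::numeric_limits<int64_t>::max()},
                                                std::numeric_limits<int64_t>::max(), {"<pad>", "<unk>"});
  if (vocab == nullptr) {
    // the cause (bad params, duplicate tokens, or a failed GetNextRow) was already logged
  }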
TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetInt64.";
@@ -318,4 +333,4 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) {
iter->GetNextRow(&row);
i++;
}
}
}
