parent
							
								
									ba01949d8f
								
							
						
					
					
						commit
						25ab2ef303
					
				| @ -0,0 +1,7 @@ | ||||
| add_subdirectory(kernels) | ||||
| 
 | ||||
| add_library(nlp OBJECT | ||||
|         vocab.cc | ||||
|         ) | ||||
| 
 | ||||
| add_dependencies(nlp nlp-kernels) | ||||
| @ -0,0 +1,3 @@ | ||||
| add_library(nlp-kernels OBJECT | ||||
|         lookup_op.cc | ||||
|         ) | ||||
| @ -0,0 +1,52 @@ | ||||
| /**
 | ||||
|  * Copyright 2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| #include "dataset/nlp/kernels/lookup_op.h" | ||||
| 
 | ||||
| #include <string> | ||||
| 
 | ||||
| namespace mindspore { | ||||
| namespace dataset { | ||||
| 
 | ||||
| LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id) | ||||
|     : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {} | ||||
| 
 | ||||
| Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
|   RETURN_UNEXPECTED_IF_NULL(vocab_); | ||||
|   CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor"); | ||||
|   std::vector<WordIdType> word_ids; | ||||
|   word_ids.reserve(input->Size()); | ||||
|   for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) { | ||||
|     word_ids.push_back(vocab_->Lookup(std::string(*itr), default_id_)); | ||||
|   } | ||||
| 
 | ||||
|   RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_, | ||||
|                                         reinterpret_cast<unsigned char *>(word_ids.data()))); | ||||
|   return Status::OK(); | ||||
| } | ||||
| Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { | ||||
|   CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match"); | ||||
|   CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type"); | ||||
|   outputs[0] = type_; | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| void LookupOp::Print(std::ostream &out) const { | ||||
|   out << "LookupOp: " | ||||
|       << "type: " << type_ << "\n default lookup id: " << default_id_ << "\n"; | ||||
| } | ||||
| 
 | ||||
| }  // namespace dataset
 | ||||
| }  // namespace mindspore
 | ||||
| @ -0,0 +1,62 @@ | ||||
| /**
 | ||||
|  * Copyright 2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| #ifndef DATASET_NLP_KERNELS_LOOKUP_OP_H_ | ||||
| #define DATASET_NLP_KERNELS_LOOKUP_OP_H_ | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <vector> | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "dataset/core/tensor.h" | ||||
| #include "dataset/kernels/tensor_op.h" | ||||
| #include "dataset/util/status.h" | ||||
| #include "dataset/nlp/vocab.h" | ||||
| 
 | ||||
| namespace mindspore { | ||||
| namespace dataset { | ||||
| class LookupOp : public TensorOp { | ||||
|  public: | ||||
|   // constructor for lookup, takes in a vocab object
 | ||||
|   // @param std::shared_ptr<Vocab> vocab -
 | ||||
|   // @param WordIdType default_id, id to lookup if a word is not in vocab
 | ||||
|   explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = Vocab::kSpecialTokens::unk); | ||||
| 
 | ||||
|   // perform actual lookup on each tensor
 | ||||
|   // @param const std::shared_ptr<Tensor> &input
 | ||||
|   // @param std::shared_ptr<Tensor> *output
 | ||||
|   // @return error code
 | ||||
|   Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | ||||
| 
 | ||||
|   // print method
 | ||||
|   // @param std::ostream out
 | ||||
|   void Print(std::ostream &out) const override; | ||||
| 
 | ||||
|   // @param std::vector<DataType> &inputs -
 | ||||
|   // @param std::vector<DataType> &outputs -
 | ||||
|   // @return error code
 | ||||
|   Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | ||||
| 
 | ||||
|  private: | ||||
|   std::shared_ptr<Vocab> vocab_; | ||||
|   WordIdType default_id_; | ||||
|   DataType type_;  // type of tensor after lookup
 | ||||
| }; | ||||
| 
 | ||||
| }  // namespace dataset
 | ||||
| }  // namespace mindspore
 | ||||
| 
 | ||||
| #endif  // DATASET_NLP_KERNELS_LOOKUP_OP_H_
 | ||||
| @ -0,0 +1,101 @@ | ||||
| /**
 | ||||
|  * Copyright 2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| #include <fstream> | ||||
| #include <map> | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "dataset/nlp/vocab.h" | ||||
| 
 | ||||
| namespace mindspore { | ||||
| namespace dataset { | ||||
| Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { | ||||
|   word2id_ = std::move(word2id); | ||||
|   id2word_.resize(word2id_.size()); | ||||
|   for (auto p : word2id_) { | ||||
|     id2word_[p.second - kSpecialTokens::num_tokens] = p.first; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const { | ||||
|   auto itr = word2id_.find(word); | ||||
|   return itr == word2id_.end() ? default_id : itr->second; | ||||
| } | ||||
| WordType Vocab::Lookup(WordIdType id) const { | ||||
|   if (id < kSpecialTokens::num_tokens) { | ||||
|     return reserved_token_str_[id]; | ||||
|   } else if (id - kSpecialTokens::num_tokens >= id2word_.size()) { | ||||
|     return reserved_token_str_[kSpecialTokens::unk]; | ||||
|   } else { | ||||
|     return id2word_[id - kSpecialTokens::num_tokens]; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| Status Vocab::BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab) { | ||||
|   std::unordered_map<WordType, WordIdType> word2id; | ||||
|   WordIdType word_id = kSpecialTokens::num_tokens; | ||||
|   for (auto word : words) { | ||||
|     const std::string s = py::str(word); | ||||
|     CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(s) == word2id.end(), "duplicate word:" + s); | ||||
|     word2id[s] = word_id++; | ||||
|   } | ||||
|   *vocab = std::make_shared<Vocab>(std::move(word2id)); | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, | ||||
|                             std::shared_ptr<Vocab> *vocab) { | ||||
|   std::unordered_map<WordType, WordIdType> word2id; | ||||
|   WordIdType word_id = kSpecialTokens::num_tokens; | ||||
|   std::fstream handle(path, std::ios::in); | ||||
|   CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open:" + path); | ||||
|   std::string word; | ||||
|   while (std::getline(handle, word)) { | ||||
|     if (!delimiter.empty()) { | ||||
|       // if delimiter is not found, find_first_of would return std::string::npos which is -1
 | ||||
|       word = word.substr(0, word.find_first_of(delimiter)); | ||||
|     } | ||||
|     CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word); | ||||
|     word2id[word] = word_id++; | ||||
|     // break if enough row is read, if vocab_size is smaller than 0
 | ||||
|     if (word_id == vocab_size + kSpecialTokens::num_tokens) break; | ||||
|   } | ||||
|   *vocab = std::make_shared<Vocab>(std::move(word2id)); | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) { | ||||
|   std::unordered_map<WordType, WordIdType> word2id; | ||||
|   std::map<WordIdType, WordType> id2word; | ||||
|   for (auto p : words) { | ||||
|     WordIdType word_id = py::reinterpret_borrow<py::int_>(p.second); | ||||
|     if (word_id < kSpecialTokens::num_tokens) continue;  // skip id that are reserved
 | ||||
|     std::string word = py::str(p.first); | ||||
|     CHECK_FAIL_RETURN_UNEXPECTED(id2word.find(word_id) == id2word.end(), "duplicate id:" + word); | ||||
|     id2word[word_id] = word; | ||||
|   } | ||||
| 
 | ||||
|   WordIdType cnt = kSpecialTokens::num_tokens; | ||||
|   for (auto p : id2word) { | ||||
|     CHECK_FAIL_RETURN_UNEXPECTED(p.first == cnt++, "word id needs to be continuous starting from 2"); | ||||
|     word2id[p.second] = p.first; | ||||
|   } | ||||
| 
 | ||||
|   *vocab = std::make_shared<Vocab>(std::move(word2id)); | ||||
|   return Status::OK(); | ||||
| } | ||||
| const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"}; | ||||
| }  // namespace dataset
 | ||||
| }  // namespace mindspore
 | ||||
| @ -0,0 +1,88 @@ | ||||
| /**
 | ||||
|  * Copyright 2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| #ifndef DATASET_NLP_VOCAB_H_ | ||||
| #define DATASET_NLP_VOCAB_H_ | ||||
| 
 | ||||
| #include <string> | ||||
| #include <memory> | ||||
| #include <unordered_map> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "dataset/util/status.h" | ||||
| #include "pybind11/pybind11.h" | ||||
| #include "pybind11/stl.h" | ||||
| 
 | ||||
| namespace mindspore { | ||||
| namespace dataset { | ||||
| namespace py = pybind11; | ||||
| 
 | ||||
| using WordIdType = int32_t; | ||||
| using WordType = std::string; | ||||
| 
 | ||||
| class Vocab { | ||||
|  public: | ||||
|   // Build a vocab from a python dictionary key is each word ,id needs to start from 2, no duplicate and continuous
 | ||||
|   // @param const py::dict &words - a dictionary containing word, word id pair.
 | ||||
|   // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
 | ||||
|   // @return error code
 | ||||
|   static Status BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab); | ||||
| 
 | ||||
|   // Build a vocab from a python list, id will be assigned automatically, start from 2
 | ||||
|   // @param const py::list &words - a list of string, used to build vocab, id starts from 2
 | ||||
|   // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
 | ||||
|   // @return error code
 | ||||
|   static Status BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab); | ||||
| 
 | ||||
|   // Build a vocab from reading a vocab file, id are automatically assigned, start from 2
 | ||||
|   // @param std::string &path - path to vocab file , each line is assumed to contain 1 word
 | ||||
|   // @param std::string &delimiter - delimiter to break each line with
 | ||||
|   // @param int32_t vocab_size - number of words to read from file
 | ||||
|   // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
 | ||||
|   // @return error code
 | ||||
|   static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, | ||||
|                               std::shared_ptr<Vocab> *vocab); | ||||
| 
 | ||||
|   // Lookup the id of a word, if word doesn't exist in vocab, return default_id
 | ||||
|   // @param const WordType word - word to look up
 | ||||
|   // @param WordIdType default_id - word id to return to user when its not in the vocab
 | ||||
|   // @return WordIdType, word_id
 | ||||
|   WordIdType Lookup(const WordType &word, WordIdType default_id) const; | ||||
| 
 | ||||
|   // reverse lookup, lookup the word based on its id
 | ||||
|   // @param WordIdType id - word id to lookup to
 | ||||
|   // @return WordType the word
 | ||||
|   WordType Lookup(WordIdType id) const; | ||||
| 
 | ||||
|   // constructor, shouldn't be called directly, can't be private due to std::make_unique()
 | ||||
|   // @param std::unordered_map<WordType, WordIdType> map - sanitized word2id map
 | ||||
|   explicit Vocab(std::unordered_map<WordType, WordIdType> map); | ||||
| 
 | ||||
|   // enum type that holds all special tokens, add more if needed
 | ||||
|   enum kSpecialTokens : WordIdType { pad = 0, unk = 1, num_tokens = 2 }; | ||||
| 
 | ||||
|   // reversed lookup table for the reserved tokens
 | ||||
|   static const std::vector<WordType> reserved_token_str_; | ||||
| 
 | ||||
|  private: | ||||
|   std::unordered_map<WordType, WordIdType> word2id_; | ||||
|   std::vector<WordType> id2word_;  // reverse lookup
 | ||||
| }; | ||||
| 
 | ||||
| }  // namespace dataset
 | ||||
| }  // namespace mindspore
 | ||||
| 
 | ||||
| #endif  // DATASET_NLP_VOCAB_H_
 | ||||
| @ -0,0 +1,19 @@ | ||||
| # Copyright 2020 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| """ | ||||
| mindspore.dataset.text | ||||
| """ | ||||
| 
 | ||||
| from .c_transforms import * | ||||
| @ -0,0 +1,77 @@ | ||||
| # Copyright 2020 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """ | ||||
| c transforms for all text related operators | ||||
| """ | ||||
| 
 | ||||
| import mindspore._c_dataengine as cde | ||||
| from .validators import check_lookup, check_from_list, check_from_dict, check_from_file | ||||
| 
 | ||||
| 
 | ||||
| class Vocab(cde.Vocab): | ||||
|     """ | ||||
|         Vocab object that is used for lookup word | ||||
|     Args: | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         pass | ||||
| 
 | ||||
|     @classmethod | ||||
|     @check_from_list | ||||
|     def from_list(cls, word_list): | ||||
|         """ | ||||
|            build a vocab object from a list of word | ||||
|         Args: | ||||
|             word_list(list): a list of string where each element is a word | ||||
|         """ | ||||
|         return super().from_list(word_list) | ||||
| 
 | ||||
|     @classmethod | ||||
|     @check_from_file | ||||
|     def from_file(cls, file_path, delimiter=None, vocab_size=None): | ||||
|         """ | ||||
|             build a vocab object from a list of word | ||||
|         Args: | ||||
|             file_path(str): path to the file which contains the vocab list | ||||
|             delimiter(None, str): a delimiter to break up each line in file, the first element is taken to be the word | ||||
|             vocab_size(None, int): number of words to read from file_path | ||||
|         """ | ||||
|         return super().from_file(file_path, delimiter, vocab_size) | ||||
| 
 | ||||
|     @classmethod | ||||
|     @check_from_dict | ||||
|     def from_dict(cls, word_dict): | ||||
|         """ | ||||
|             build a vocab object from a dict. | ||||
|         Args: | ||||
|             word_dict(dict): dict contains word, id pairs. id should start from 2 and continuous | ||||
|         """ | ||||
|         return super().from_dict(word_dict) | ||||
| 
 | ||||
| 
 | ||||
| class Lookup(cde.LookupOp): | ||||
|     """ | ||||
|         Lookup operator that looks up a word to an id | ||||
|     Args: | ||||
|         vocab(Vocab): a Vocab object | ||||
|         unknown(None,int): default id to lookup a word that is out of vocab | ||||
|     """ | ||||
| 
 | ||||
|     @check_lookup | ||||
|     def __init__(self, vocab, unknown=None): | ||||
|         if unknown is None: | ||||
|             super().__init__(vocab) | ||||
|         else: | ||||
|             super().__init__(vocab, unknown) | ||||
| @ -0,0 +1,108 @@ | ||||
| # Copyright 2019 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """ | ||||
| validators for text ops | ||||
| """ | ||||
| 
 | ||||
| from functools import wraps | ||||
| import mindspore._c_dataengine as cde | ||||
| 
 | ||||
| 
 | ||||
| def check_lookup(method): | ||||
|     """A wrapper that wrap a parameter checker to the original function(crop operation).""" | ||||
| 
 | ||||
|     @wraps(method) | ||||
|     def new_method(self, *args, **kwargs): | ||||
|         vocab, unknown = (list(args) + 2 * [None])[:2] | ||||
|         if "vocab" in kwargs: | ||||
|             vocab = kwargs.get("vocab") | ||||
|         if "unknown" in kwargs: | ||||
|             unknown = kwargs.get("unknown") | ||||
|         if unknown is not None: | ||||
|             assert isinstance(unknown, int) and unknown >= 0, "unknown needs to be a non-negative integer" | ||||
| 
 | ||||
|         assert isinstance(vocab, cde.Vocab), "vocab is not an instance of cde.Vocab" | ||||
| 
 | ||||
|         kwargs["vocab"] = vocab | ||||
|         kwargs["unknown"] = unknown | ||||
|         return method(self, **kwargs) | ||||
| 
 | ||||
|     return new_method | ||||
| 
 | ||||
| 
 | ||||
| def check_from_file(method): | ||||
|     """A wrapper that wrap a parameter checker to the original function(crop operation).""" | ||||
| 
 | ||||
|     @wraps(method) | ||||
|     def new_method(self, *args, **kwargs): | ||||
|         file_path, delimiter, vocab_size = (list(args) + 3 * [None])[:3] | ||||
|         if "file_path" in kwargs: | ||||
|             file_path = kwargs.get("file_path") | ||||
|         if "delimiter" in kwargs: | ||||
|             delimiter = kwargs.get("delimiter") | ||||
|         if "vocab_size" in kwargs: | ||||
|             vocab_size = kwargs.get("vocab_size") | ||||
| 
 | ||||
|         assert isinstance(file_path, str), "file_path needs to be str" | ||||
|         if delimiter is not None: | ||||
|             assert isinstance(delimiter, str), "delimiter needs to be str" | ||||
|         else: | ||||
|             delimiter = "" | ||||
|         if vocab_size is not None: | ||||
|             assert isinstance(vocab_size, int) and vocab_size > 0, "vocab size needs to be a positive integer" | ||||
|         else: | ||||
|             vocab_size = -1 | ||||
|         kwargs["file_path"] = file_path | ||||
|         kwargs["delimiter"] = delimiter | ||||
|         kwargs["vocab_size"] = vocab_size | ||||
|         return method(self, **kwargs) | ||||
| 
 | ||||
|     return new_method | ||||
| 
 | ||||
| 
 | ||||
| def check_from_list(method): | ||||
|     """A wrapper that wrap a parameter checker to the original function(crop operation).""" | ||||
| 
 | ||||
|     @wraps(method) | ||||
|     def new_method(self, *args, **kwargs): | ||||
|         word_list, = (list(args) + [None])[:1] | ||||
|         if "word_list" in kwargs: | ||||
|             word_list = kwargs.get("word_list") | ||||
|         assert isinstance(word_list, list), "word_list needs to be a list of words" | ||||
|         for word in word_list: | ||||
|             assert isinstance(word, str), "each word in word list needs to be type str" | ||||
| 
 | ||||
|         kwargs["word_list"] = word_list | ||||
|         return method(self, **kwargs) | ||||
| 
 | ||||
|     return new_method | ||||
| 
 | ||||
| 
 | ||||
| def check_from_dict(method): | ||||
|     """A wrapper that wrap a parameter checker to the original function(crop operation).""" | ||||
| 
 | ||||
|     @wraps(method) | ||||
|     def new_method(self, *args, **kwargs): | ||||
|         word_dict, = (list(args) + [None])[:1] | ||||
|         if "word_dict" in kwargs: | ||||
|             word_dict = kwargs.get("word_dict") | ||||
|         assert isinstance(word_dict, dict), "word_dict needs to be a list of word,id pairs" | ||||
|         for word, word_id in word_dict.items(): | ||||
|             assert isinstance(word, str), "each word in word_dict needs to be type str" | ||||
|             assert isinstance(word_id, int) and word_id >= 0, "each word id needs to be positive integer" | ||||
|         kwargs["word_dict"] = word_dict | ||||
|         return method(self, **kwargs) | ||||
| 
 | ||||
|     return new_method | ||||
| @ -0,0 +1,14 @@ | ||||
| not,1 | ||||
| all,2 | ||||
| those,3 | ||||
| who,4 | ||||
| wonder,5 | ||||
| are,6 | ||||
| lost,7 | ||||
| Tolkein,8 | ||||
| home,9 | ||||
| is,10 | ||||
| behind,11 | ||||
| world,12 | ||||
| ahead,13 | ||||
| the,14 | ||||
| @ -0,0 +1,6 @@ | ||||
| home | ||||
| is | ||||
| behind | ||||
| the | ||||
| world | ||||
| ahead | ||||
| @ -0,0 +1,47 @@ | ||||
| import mindspore.dataset as ds | ||||
| import mindspore.dataset.text as text | ||||
| 
 | ||||
| # this file contains "home is behind the world head" each word is 1 line | ||||
| DATA_FILE = "../data/dataset/testVocab/words.txt" | ||||
| VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt" | ||||
| 
 | ||||
| 
 | ||||
| def test_from_list(): | ||||
|     vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" ")) | ||||
|     lookup = text.Lookup(vocab) | ||||
|     data = ds.TextFileDataset(DATA_FILE, shuffle=False) | ||||
|     data = data.map(input_columns=["text"], operations=lookup) | ||||
|     ind = 0 | ||||
|     res = [2, 1, 4, 5, 6, 7] | ||||
|     for d in data.create_dict_iterator(): | ||||
|         assert d["text"] == res[ind], ind | ||||
|         ind += 1 | ||||
| 
 | ||||
| 
 | ||||
| def test_from_file(): | ||||
|     vocab = text.Vocab.from_file(VOCAB_FILE, ",") | ||||
|     lookup = text.Lookup(vocab) | ||||
|     data = ds.TextFileDataset(DATA_FILE, shuffle=False) | ||||
|     data = data.map(input_columns=["text"], operations=lookup) | ||||
|     ind = 0 | ||||
|     res = [10, 11, 12, 15, 13, 14] | ||||
|     for d in data.create_dict_iterator(): | ||||
|         assert d["text"] == res[ind], ind | ||||
|         ind += 1 | ||||
| 
 | ||||
| 
 | ||||
| def test_from_dict(): | ||||
|     vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6}) | ||||
|     lookup = text.Lookup(vocab, 6)  # default value is -1 | ||||
|     data = ds.TextFileDataset(DATA_FILE, shuffle=False) | ||||
|     data = data.map(input_columns=["text"], operations=lookup) | ||||
|     res = [3, 6, 2, 4, 5, 6] | ||||
|     ind = 0 | ||||
|     for d in data.create_dict_iterator(): | ||||
|         assert d["text"] == res[ind], ind | ||||
|         ind += 1 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     test_from_list() | ||||
|     test_from_file() | ||||
|     test_from_dict() | ||||
					Loading…
					
					
				
		Reference in new issue