|
|
|
@ -26,7 +26,7 @@ LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id)
|
|
|
|
|
Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
|
|
|
|
IO_CHECK(input, output);
|
|
|
|
|
RETURN_UNEXPECTED_IF_NULL(vocab_);
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor.");
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None string tensor received.");
|
|
|
|
|
std::vector<WordIdType> word_ids;
|
|
|
|
|
word_ids.reserve(input->Size());
|
|
|
|
|
for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) {
|
|
|
|
@ -34,7 +34,7 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
|
|
|
|
|
word_ids.emplace_back(word_id == Vocab::kNoTokenExists ? default_id_ : word_id);
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(
|
|
|
|
|
word_ids.back() != Vocab::kNoTokenExists,
|
|
|
|
|
"Lookup Error: token" + std::string(*itr) + "doesn't exist in vocab and no unknown token is specified.");
|
|
|
|
|
"Lookup Error: token: " + std::string(*itr) + " doesn't exist in vocab and no unknown token is specified.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_,
|
|
|
|
@ -42,8 +42,8 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match");
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type");
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match.");
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type.");
|
|
|
|
|
outputs[0] = type_;
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|