|
|
|
@ -28,6 +28,7 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
|
|
|
|
|
IO_CHECK(input, output);
|
|
|
|
|
RETURN_UNEXPECTED_IF_NULL(vocab_);
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None string tensor received.");
|
|
|
|
|
|
|
|
|
|
std::vector<WordIdType> word_ids;
|
|
|
|
|
word_ids.reserve(input->Size());
|
|
|
|
|
for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) {
|
|
|
|
@ -41,6 +42,8 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
|
|
|
|
|
|
|
|
|
|
// type cast to user's requirements if what user wants isn't int32_t
|
|
|
|
|
if ((*output)->type() != type_) {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(),
|
|
|
|
|
"Lookup doesn't support string to string lookup. data_type needs to be numeric");
|
|
|
|
|
std::shared_ptr<Tensor> cast_to;
|
|
|
|
|
RETURN_IF_NOT_OK(TypeCast(*output, &cast_to, type_));
|
|
|
|
|
*output = cast_to;
|
|
|
|
|