diff --git a/mindspore/lite/src/common/string_util.cc b/mindspore/lite/src/common/string_util.cc index 99c0830270..82de77790e 100644 --- a/mindspore/lite/src/common/string_util.cc +++ b/mindspore/lite/src/common/string_util.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace lite { std::vector ParseTensorBuffer(Tensor *tensor) { - if (tensor->MutableData() == nullptr) { + if (tensor->data_c() == nullptr) { MS_LOG(ERROR) << "Tensor data is null, cannot be parsed"; return std::vector{}; } diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 649f0e1d12..d68200d6e3 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -65,8 +65,12 @@ int LiteSession::ConvertTensors(const lite::Model *model) { MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr"; } else { if (TensorCategory(srcTensor) == Tensor::Category::CONST) { - for (size_t j = 0; j < srcTensor->dims()->size(); j++) { - shape.push_back(srcTensor->dims()->data()[j]); + if (srcTensor->dataType() == kObjectTypeString && srcTensor->data() != nullptr) { + shape.push_back(srcTensor->data()->size()); + } else { + for (size_t j = 0; j < srcTensor->dims()->size(); j++) { + shape.push_back(srcTensor->dims()->data()[j]); + } } } } diff --git a/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc index 501571a3c6..596aa90c19 100644 --- a/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc +++ b/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc @@ -48,11 +48,12 @@ int HashtableLookupCPUKernel::Run() { auto output_tensor = out_tensors_.at(0); auto hits_tensor = out_tensors_.at(1); - int rows = values_tensor->DimensionSize(0); + int rows = GetStringCount(values_tensor); int32_t *input_data = reinterpret_cast(input_tensor->MutableData()); uint8_t *hits_data = reinterpret_cast(hits_tensor->MutableData()); - std::vector output_string_pack; + std::vector output_string_pack(input_tensor->ElementsNum()); std::vector all_string_pack = ParseTensorBuffer(values_tensor); + lite::StringPack null_string_pack = {0, nullptr}; for (int i = 0; i < input_tensor->ElementsNum(); i++) { int index = -1; @@ -61,11 +62,10 @@ int HashtableLookupCPUKernel::Run() { index = reinterpret_cast(p) - reinterpret_cast(keys_tensor->MutableData()); } if (index >= rows || index < 0) { - lite::StringPack tmp = {0, nullptr}; - output_string_pack.push_back(tmp); + output_string_pack[i] = null_string_pack; hits_data[i] = 0; } else { - output_string_pack.push_back(all_string_pack[i]); + output_string_pack[i] = all_string_pack[i]; hits_data[i] = 1; } } diff --git a/mindspore/lite/src/runtime/kernel/arm/string/predict.cc b/mindspore/lite/src/runtime/kernel/arm/string/predict.cc index fc552a075b..15c2fc9085 100644 --- a/mindspore/lite/src/runtime/kernel/arm/string/predict.cc +++ b/mindspore/lite/src/runtime/kernel/arm/string/predict.cc @@ -88,9 +88,10 @@ int PredictCPUKernel::Run() { if (static_cast(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) { output_label[i] = -1; output_weight[i] = 0.0f; + } else { + output_label[i] = label_info_vec[i].label; + output_weight[i] = label_info_vec[i].weight; } - output_label[i] = label_info_vec[i].label; - output_weight[i] = label_info_vec[i].weight; } return RET_OK; }