fix cannot find output bug

pull/11848/head
xuanyue 4 years ago
parent 072f6025cc
commit 48dabf90bc

@ -206,7 +206,10 @@ int Benchmark::ReadTensorData(std::ifstream &in_file_stream, const std::string &
std::string line;
getline(in_file_stream, line);
std::stringstream line_stream(line);
tensor::MSTensor *tensor = GetTensorByNodeOrTensorName(tensor_name);
if (this->benchmark_data_.find(tensor_name) != this->benchmark_data_.end()) {
return RET_OK;
}
tensor::MSTensor *tensor = GetTensorByNameOrShape(tensor_name, dims);
if (tensor == nullptr) {
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
return RET_ERROR;
@ -242,7 +245,7 @@ int Benchmark::CompareOutput() {
int total_size = 0;
for (const auto &calib_tensor : benchmark_data_) {
std::string node_or_tensor_name = calib_tensor.first;
tensor::MSTensor *tensor = GetTensorByNodeOrTensorName(node_or_tensor_name);
tensor::MSTensor *tensor = GetTensorByNameOrShape(node_or_tensor_name, calib_tensor.second->shape);
if (tensor == nullptr) {
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << node_or_tensor_name;
return RET_ERROR;
@ -278,13 +281,35 @@ int Benchmark::CompareOutput() {
return RET_OK;
}
tensor::MSTensor *Benchmark::GetTensorByNodeOrTensorName(const std::string &node_or_tensor_name) {
tensor::MSTensor *Benchmark::GetTensorByNodeShape(const std::vector<size_t> &node_shape) {
std::vector<tensor::MSTensor *> match_tensors;
std::vector<int> shape_vector;
(void)std::transform(node_shape.begin(), node_shape.end(), std::back_inserter(shape_vector),
[](const size_t &value) { return static_cast<int>(value); });
auto tensors = session_->GetOutputs();
for (auto &out_tensor_pair : tensors) {
if (out_tensor_pair.second->shape() == shape_vector) {
match_tensors.emplace_back(out_tensor_pair.second);
}
}
if (match_tensors.empty() || match_tensors.size() != 1) {
MS_LOG(ERROR) << "get tensor by node shape failed";
return nullptr;
}
return match_tensors.front();
}
tensor::MSTensor *Benchmark::GetTensorByNameOrShape(const std::string &node_or_tensor_name,
const std::vector<size_t> &dims) {
tensor::MSTensor *tensor = nullptr;
auto tensors = session_->GetOutputsByNodeName(node_or_tensor_name);
if (tensors.empty() || tensors.size() != 1) {
MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name
<< " or node has more than one output tensor, switch to GetOutputByTensorName";
tensor = session_->GetOutputByTensorName(node_or_tensor_name);
if (tensor == nullptr) {
return GetTensorByNodeShape(dims);
}
} else {
tensor = tensors.front();
}

@ -135,7 +135,9 @@ class MS_API Benchmark {
int CompareOutput();
tensor::MSTensor *GetTensorByNodeOrTensorName(const std::string &node_or_tensor_name);
tensor::MSTensor *GetTensorByNameOrShape(const std::string &node_or_tensor_name, const std::vector<size_t> &dims);
tensor::MSTensor *GetTensorByNodeShape(const std::vector<size_t> &node_shape);
int CompareStringData(const std::string &name, tensor::MSTensor *tensor);

Loading…
Cancel
Save