|
|
|
@ -206,7 +206,10 @@ int Benchmark::ReadTensorData(std::ifstream &in_file_stream, const std::string &
|
|
|
|
|
std::string line;
|
|
|
|
|
getline(in_file_stream, line);
|
|
|
|
|
std::stringstream line_stream(line);
|
|
|
|
|
tensor::MSTensor *tensor = GetTensorByNodeOrTensorName(tensor_name);
|
|
|
|
|
if (this->benchmark_data_.find(tensor_name) != this->benchmark_data_.end()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
tensor::MSTensor *tensor = GetTensorByNameOrShape(tensor_name, dims);
|
|
|
|
|
if (tensor == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
@ -242,7 +245,7 @@ int Benchmark::CompareOutput() {
|
|
|
|
|
int total_size = 0;
|
|
|
|
|
for (const auto &calib_tensor : benchmark_data_) {
|
|
|
|
|
std::string node_or_tensor_name = calib_tensor.first;
|
|
|
|
|
tensor::MSTensor *tensor = GetTensorByNodeOrTensorName(node_or_tensor_name);
|
|
|
|
|
tensor::MSTensor *tensor = GetTensorByNameOrShape(node_or_tensor_name, calib_tensor.second->shape);
|
|
|
|
|
if (tensor == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << node_or_tensor_name;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
@ -278,13 +281,35 @@ int Benchmark::CompareOutput() {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tensor::MSTensor *Benchmark::GetTensorByNodeOrTensorName(const std::string &node_or_tensor_name) {
|
|
|
|
|
tensor::MSTensor *Benchmark::GetTensorByNodeShape(const std::vector<size_t> &node_shape) {
|
|
|
|
|
std::vector<tensor::MSTensor *> match_tensors;
|
|
|
|
|
std::vector<int> shape_vector;
|
|
|
|
|
(void)std::transform(node_shape.begin(), node_shape.end(), std::back_inserter(shape_vector),
|
|
|
|
|
[](const size_t &value) { return static_cast<int>(value); });
|
|
|
|
|
auto tensors = session_->GetOutputs();
|
|
|
|
|
for (auto &out_tensor_pair : tensors) {
|
|
|
|
|
if (out_tensor_pair.second->shape() == shape_vector) {
|
|
|
|
|
match_tensors.emplace_back(out_tensor_pair.second);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (match_tensors.empty() || match_tensors.size() != 1) {
|
|
|
|
|
MS_LOG(ERROR) << "get tensor by node shape failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return match_tensors.front();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tensor::MSTensor *Benchmark::GetTensorByNameOrShape(const std::string &node_or_tensor_name,
|
|
|
|
|
const std::vector<size_t> &dims) {
|
|
|
|
|
tensor::MSTensor *tensor = nullptr;
|
|
|
|
|
auto tensors = session_->GetOutputsByNodeName(node_or_tensor_name);
|
|
|
|
|
if (tensors.empty() || tensors.size() != 1) {
|
|
|
|
|
MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name
|
|
|
|
|
<< " or node has more than one output tensor, switch to GetOutputByTensorName";
|
|
|
|
|
tensor = session_->GetOutputByTensorName(node_or_tensor_name);
|
|
|
|
|
if (tensor == nullptr) {
|
|
|
|
|
return GetTensorByNodeShape(dims);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
tensor = tensors.front();
|
|
|
|
|
}
|
|
|
|
|