diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc index a370f32f5e..973220c96c 100644 --- a/mindspore/lite/tools/benchmark/benchmark.cc +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -419,24 +419,8 @@ int Benchmark::RunBenchmark() { std::cout << "CompileGraph failed while running ", model_name.c_str(); return ret; } - if (!flags_->input_shape_list_.empty()) { - std::vector> input_shapes; - std::string input_dims_list = flags_->input_shape_list_; - while (!input_dims_list.empty()) { - auto position = - input_dims_list.find(";") != input_dims_list.npos ? input_dims_list.find(";") + 1 : input_dims_list.length(); - std::string input_dims = input_dims_list.substr(0, position); - std::vector input_shape; - while (!input_dims.empty()) { - auto pos = input_dims.find(",") != input_dims.npos ? input_dims.find(",") + 1 : input_dims.length(); - std::string dim = input_dims.substr(0, pos); - input_shape.emplace_back(std::stoi(dim)); - input_dims = input_dims.substr(pos); - } - input_shapes.emplace_back(input_shape); - input_dims_list = input_dims_list.substr(position); - } - ret = session_->Resize(session_->GetInputs(), input_shapes); + if (!flags_->resize_dims_.empty()) { + ret = session_->Resize(session_->GetInputs(), flags_->resize_dims_); if (ret != RET_OK) { MS_LOG(ERROR) << "Input tensor resize failed."; std::cout << "Input tensor resize failed."; @@ -496,7 +480,7 @@ void BenchmarkFlags::InitInputDataList() { void BenchmarkFlags::InitResizeDimsList() { std::string content; content = this->resize_dims_in_; - std::vector shape; + std::vector shape; auto shape_strs = StringSplit(content, std::string(DELIM_COLON)); for (const auto &shape_str : shape_strs) { shape.clear(); @@ -504,7 +488,7 @@ void BenchmarkFlags::InitResizeDimsList() { std::cout << "Resize Dims: "; for (const auto &dim_str : dim_strs) { std::cout << dim_str << " "; - shape.emplace_back(static_cast(std::stoi(dim_str))); + shape.emplace_back(static_cast(std::stoi(dim_str))); } std::cout << std::endl; this->resize_dims_.emplace_back(shape); @@ -616,7 +600,8 @@ int Benchmark::Init() { } flags_->InitInputDataList(); flags_->InitResizeDimsList(); - if (!flags_->resize_dims_.empty() && flags_->resize_dims_.size() != flags_->input_data_list_.size()) { + if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() && + flags_->resize_dims_.size() != flags_->input_data_list_.size()) { MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath"; std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl; return RET_ERROR; diff --git a/mindspore/lite/tools/benchmark/benchmark.h b/mindspore/lite/tools/benchmark/benchmark.h index 198cc38491..5795e9ed3c 100644 --- a/mindspore/lite/tools/benchmark/benchmark.h +++ b/mindspore/lite/tools/benchmark/benchmark.h @@ -70,8 +70,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser { AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType", "Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT"); AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); - AddFlag(&BenchmarkFlags::input_shape_list_, "inputShapes", - "Shape of input data, the format should be NHWC. e.g. 1,32,32,32;1,1,32,32,1", ""); + AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes", + "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", ""); } ~BenchmarkFlags() override = default; @@ -88,7 +88,6 @@ class MS_API BenchmarkFlags : public virtual FlagParser { InDataType in_data_type_; std::string in_data_type_in_ = "bin"; int cpu_bind_mode_ = 1; - std::string input_shape_list_; // MarkPerformance int loop_count_; int num_threads_; @@ -101,7 +100,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser { float accuracy_threshold_; // Resize std::string resize_dims_in_ = ""; - std::vector> resize_dims_; + std::vector> resize_dims_; std::string device_; };