Modify usage of the benchmark parameter 'inputShapes'

pull/7717/head
wang_shaocong 4 years ago
parent acd156c084
commit daba04988d

@ -419,24 +419,8 @@ int Benchmark::RunBenchmark() {
std::cout << "CompileGraph failed while running ", model_name.c_str();
return ret;
}
if (!flags_->input_shape_list_.empty()) {
std::vector<std::vector<int>> input_shapes;
std::string input_dims_list = flags_->input_shape_list_;
while (!input_dims_list.empty()) {
auto position =
input_dims_list.find(";") != input_dims_list.npos ? input_dims_list.find(";") + 1 : input_dims_list.length();
std::string input_dims = input_dims_list.substr(0, position);
std::vector<int> input_shape;
while (!input_dims.empty()) {
auto pos = input_dims.find(",") != input_dims.npos ? input_dims.find(",") + 1 : input_dims.length();
std::string dim = input_dims.substr(0, pos);
input_shape.emplace_back(std::stoi(dim));
input_dims = input_dims.substr(pos);
}
input_shapes.emplace_back(input_shape);
input_dims_list = input_dims_list.substr(position);
}
ret = session_->Resize(session_->GetInputs(), input_shapes);
if (!flags_->resize_dims_.empty()) {
ret = session_->Resize(session_->GetInputs(), flags_->resize_dims_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Input tensor resize failed.";
std::cout << "Input tensor resize failed.";
@ -496,7 +480,7 @@ void BenchmarkFlags::InitInputDataList() {
void BenchmarkFlags::InitResizeDimsList() {
std::string content;
content = this->resize_dims_in_;
std::vector<int64_t> shape;
std::vector<int> shape;
auto shape_strs = StringSplit(content, std::string(DELIM_COLON));
for (const auto &shape_str : shape_strs) {
shape.clear();
@ -504,7 +488,7 @@ void BenchmarkFlags::InitResizeDimsList() {
std::cout << "Resize Dims: ";
for (const auto &dim_str : dim_strs) {
std::cout << dim_str << " ";
shape.emplace_back(static_cast<int64_t>(std::stoi(dim_str)));
shape.emplace_back(static_cast<int>(std::stoi(dim_str)));
}
std::cout << std::endl;
this->resize_dims_.emplace_back(shape);
@ -616,7 +600,8 @@ int Benchmark::Init() {
}
flags_->InitInputDataList();
flags_->InitResizeDimsList();
if (!flags_->resize_dims_.empty() && flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
return RET_ERROR;

@ -70,8 +70,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType",
"Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT");
AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
AddFlag(&BenchmarkFlags::input_shape_list_, "inputShapes",
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32;1,1,32,32,1", "");
AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes",
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
}
~BenchmarkFlags() override = default;
@ -88,7 +88,6 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
InDataType in_data_type_;
std::string in_data_type_in_ = "bin";
int cpu_bind_mode_ = 1;
std::string input_shape_list_;
// MarkPerformance
int loop_count_;
int num_threads_;
@ -101,7 +100,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
float accuracy_threshold_;
// Resize
std::string resize_dims_in_ = "";
std::vector<std::vector<int64_t>> resize_dims_;
std::vector<std::vector<int>> resize_dims_;
std::string device_;
};

Loading…
Cancel
Save