Support assigned input shapes while run benchmark tool.

pull/7260/head
wang_shaocong 4 years ago
parent 8c329605d2
commit db63e4e5f6

@ -406,6 +406,30 @@ int Benchmark::RunBenchmark() {
std::cout << "CompileGraph failed while running ", model_name.c_str();
return ret;
}
if (!flags_->input_shape_list_.empty()) {
std::vector<std::vector<int>> input_shapes;
std::string input_dims_list = flags_->input_shape_list_;
while (!input_dims_list.empty()) {
auto position =
input_dims_list.find(";") != input_dims_list.npos ? input_dims_list.find(";") + 1 : input_dims_list.length();
std::string input_dims = input_dims_list.substr(0, position);
std::vector<int> input_shape;
while (!input_dims.empty()) {
auto pos = input_dims.find(",") != input_dims.npos ? input_dims.find(",") + 1 : input_dims.length();
std::string dim = input_dims.substr(0, pos);
input_shape.emplace_back(std::stoi(dim));
input_dims = input_dims.substr(pos);
}
input_shapes.emplace_back(input_shape);
input_dims_list = input_dims_list.substr(position);
}
ret = session_->Resize(session_->GetInputs(), input_shapes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Input tensor resize failed.";
std::cout << "Input tensor resize failed.";
return ret;
}
}
model->Free();
ms_inputs_ = session_->GetInputs();
auto end_prepare_time = GetTimeUs();

@ -70,6 +70,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType",
"Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT");
AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
AddFlag(&BenchmarkFlags::input_shape_list_, "inputShapes",
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32;1,1,32,32,1", "");
}
~BenchmarkFlags() override = default;
@ -86,6 +88,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
InDataType in_data_type_;
std::string in_data_type_in_ = "bin";
int cpu_bind_mode_ = 1;
std::string input_shape_list_;
// MarkPerformance
int loop_count_;
int num_threads_;

@ -26,6 +26,9 @@ using mindspore::lite::Tensor;
namespace mindspore {
namespace lite {
namespace {
constexpr int DEFAULT_DIM_VALUE = -1;
}
namespace {
std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs,
const schema::PrimitiveType node_type) {
std::vector<Tensor *> lite_tensors;
@ -85,6 +88,15 @@ void FreeTensors(std::vector<Tensor *> input_tensors, std::vector<Tensor *> outp
} // namespace
STATUS InferShapePass::Run(MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (auto idx : graph->inputIndex) {
auto input_tensor = graph->allTensors[idx].get();
for (auto &dim : input_tensor->dims) {
if (dim == 0) {
MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to 32 as a default value.";
dim = DEFAULT_DIM_VALUE;
}
}
}
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type);

@ -41,9 +41,16 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "value") {
if (onnx_node_attr.type() == onnx::AttributeProto_AttributeType_TENSOR) {
auto tensor = onnx_node_attr.t();
if (tensor.data_type() == onnx::AttributeProto_AttributeType_FLOAT) {
attr->value = onnx_node_attr.f();
} else if (tensor.data_type() == onnx::AttributeProto_AttributeType_INT) {
attr->value = static_cast<int32_t>(onnx_node_attr.i());
}
}
}
}
op->primitive->value.type = schema::PrimitiveType_ConstantOfShape;
op->primitive->value.value = attr.release();

@ -66,14 +66,14 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "kernel_shape") {
if (onnx_node_attr.ints_size() == 2) {
attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(1));
attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(1));
}
}
if (attribute_name == "strides") {
if (onnx_node_attr.ints_size() == 2) {
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
}
}
if (attribute_name == "auto_pad") {

Loading…
Cancel
Save