From 54b6ce9eea0f78ea3bad270fa3711e7da2155381 Mon Sep 17 00:00:00 2001 From: l00444296 Date: Mon, 14 Dec 2020 21:30:42 +0800 Subject: [PATCH] Feature: Get default op format from ge graph --- ge/ir_build/ge_ir_build.cc | 110 +++++++++++++++++++++++++++++++------ 1 file changed, 93 insertions(+), 17 deletions(-) diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index f181170c..34e612a2 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -226,9 +226,11 @@ class Impl { }; ~Impl() { (void)generator_.Finalize(); }; graphStatus CheckOptions(const std::map &options); + graphStatus CheckInputFormat(const string &input_format); graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector &inputs); - graphStatus GetDefaultInputShape(const Graph &graph, string &default_shape); - graphStatus UpdateDataOpAttr(const Graph &graph); + graphStatus GetDefaultInputShape(const Graph &graph, string &default_shape, bool &dynamic_shape_flag); + graphStatus GetDefaultInputFormat(const Graph &graph, string &default_format); + const Graph &graph, string &default_shape, string &input_fo graphStatus UpdateDataOpAttr(const Graph &graph); graphStatus Init(const Graph &graph, const std::map &options); graphStatus BuildModel(const Graph &graph, const std::map &options, ModelBufferData &ge_models); @@ -321,7 +323,62 @@ graphStatus Impl::CheckOptions(const std::map &options return GRAPH_SUCCESS; } -graphStatus Impl::GetDefaultInputShape(const Graph &graph, string &default_shape) { +graphStatus Impl::CheckInputFormat(const string &input_format) { + if (!input_format.empty()) { + auto iter = ge::input_format_str_to_geformat.find(input_format); + if (iter == ge::input_format_str_to_geformat.end()) { + GELOGE(GRAPH_PARAM_INVALID, "Input format %s not support , expect ND/NCHW/NHWC/CHWN/NC1HWC0/NHWC1C0.", + input_format.c_str()); + return GRAPH_PARAM_INVALID; + } + } + return GRAPH_SUCCESS; +} + +graphStatus Impl::GetDefaultInputFormat(const Graph &graph, string &default_format) { + auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(input_node); + ge::OpDescPtr op = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op); + if (op->GetType() == DATA) { + string data_op_name = op->GetName(); + GELOGD("Data op name: %s, data op inputDesc size: %zu", data_op_name.c_str(), op->GetAllInputsDesc().size()); + ge::GeTensorDesc tensor = op->GetInputDesc(0); + ge::GeShape data_shape = tensor.GetShape(); + GELOGD("Data op get shape from InputDesc in ge ir graph."); + + const std::vector &tmp_shape = data_shape.GetDims(); + if (tmp_shape.empty()) { + GELOGD("Data op: %s has zero shape dims!", data_op_name.c_str()); + continue; + } + + bool is_dynamic_input = false; + for (auto tmp_dim : tmp_shape) { + if (tmp_dim < 0) { + is_dynamic_input = true; + } + } + + if (is_dynamic_input) { + string tmp_data_format = ge::TypeUtils::FormatToSerialString(tensor.GetFormat()); + if (!default_format.empty() && tmp_data_format!=default_format) { + GELOGE(GRAPH_PARAM_INVALID, "All data op with dynamic shape has no default format!"); + return GRAPH_PARAM_INVALID; + } else if (default_format.empty()) { + default_format.assign(tmp_data_format); + } + GELOGD("Data op name: %s, data format: %s.", data_op_name.c_str(), default_format.c_str()); + } + } + } + GELOGI("Get default data op format: %s from ge ir graph.", default_format.c_str()); + return GRAPH_SUCCESS; +} + +graphStatus Impl::(const Graph &graph, string &default_shape, bool &dynamic_shape_flag) { auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { @@ -335,21 +392,30 @@ graphStatus Impl::GetDefaultInputShape(const Graph &graph, string &default_shape ge::GeShape data_shape = tensor.GetShape(); GELOGD("Data op get shape from InputDesc in ge ir graph."); - string tmp_shape_str; const std::vector &tmp_shape = data_shape.GetDims(); if (tmp_shape.empty()) { GELOGW("Data op: %s has zero shape dims!", data_op_name.c_str()); - } else { - tmp_shape_str += data_op_name + ":"; - for (auto tmp_dim : tmp_shape) { - tmp_shape_str += to_string((long)tmp_dim) + ","; + continue; + } + + string tmp_shape_str; + bool is_dynamic_input = false; + + tmp_shape_str += data_op_name + ":"; + for (auto tmp_dim : tmp_shape) { + if (tmp_dim < 0) { + is_dynamic_input = true; } - tmp_shape_str = tmp_shape_str.substr(0, tmp_shape_str.size() - 1); - tmp_shape_str += ";"; - default_shape += tmp_shape_str; + tmp_shape_str += to_string((long)tmp_dim) + ","; } + tmp_shape_str = tmp_shape_str.substr(0, tmp_shape_str.size() - 1); + tmp_shape_str += ";"; - GELOGD("Data op name: %s, data shape: %s.", data_op_name.c_str(), tmp_shape_str.c_str()); + if (is_dynamic_input) { + dynamic_shape_flag = true; + default_shape += tmp_shape_str; + GELOGD("Data op name: %s, data shape: %s.", data_op_name.c_str(), tmp_shape_str.c_str(),); + } } } default_shape = (default_shape.empty() ? default_shape : default_shape.substr(0, default_shape.size() - 1)); @@ -378,14 +444,24 @@ graphStatus Impl::Init(const Graph &graph, const std::map