From 1f0268e87a6f50bc15e18f757a53c8883cc3f8c9 Mon Sep 17 00:00:00 2001 From: wxl Date: Wed, 9 Dec 2020 14:11:41 +0800 Subject: [PATCH] ir build optimize --- ge/ir_build/ge_ir_build.cc | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 34663493..6dfda036 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -50,6 +50,9 @@ const std::string IR_OPTION_LOG_LEVEL_DEFAULT = "default"; const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; + +const std::string kInputShape = "input_shape"; +const std::string kInputFormat = "input_format"; } // namespace static graphStatus CheckGlobalOptions(std::map &global_options) { @@ -232,6 +235,7 @@ class Impl { ModelBufferData &ge_models); graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, bool is_dynamic_input); + graphStatus UpdateDataOp(const Graph &graph); void SetRtSocVersion(); void UpdateThreadContext(); void LoadOpsProto(); @@ -242,6 +246,36 @@ class Impl { OmgContext omg_context_; }; +graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { + GELOGD("Enter Update Data Attr Process!"); + if (options_.find(kInputShape) == options_.end()) { + return GRAPH_SUCCESS; + } + unordered_map> shape_map; + vector>> user_shape_map; + GE_CHK_BOOL_EXEC(ParseInputShape(options_[kInputShape], shape_map, user_shape_map, true), + return GRAPH_PARAM_INVALID, "parse input shape failed!"); + auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(input_node); + ge::OpDescPtr op = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op); + if (op->GetType() == DATA) { + auto tensor = op->MutableInputDesc(0); + string data_op_name = op->GetName(); + auto iter = shape_map.find(data_op_name); + if (iter != shape_map.end()) { + tensor->SetShape(ge::GeShape(iter->second)); + GELOGD("update input [%s] shape info", data_op_name.c_str()); + } else { + GELOGI("no need update input [%s] attr because not found from input_shape.", data_op_name.c_str()); + } + } + } + return GRAPH_SUCCESS; +} + graphStatus Impl::CheckOptions(const std::map &options) { for (auto &ele : options) { auto it = ge::ir_option::ir_builder_suppported_options.find(ele.first); @@ -437,7 +471,6 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vectorsecond); - tensor->SetShape(data_shape); GELOGD("Data op get shape from Context and update [%s] shape info", data_op_name.c_str()); } else { data_shape = tensor->GetShape();