diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index c7ef6c1a..95fb6749 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -36,6 +36,9 @@ #include "model/ge_model.h" #include "graph/shape_refiner.h" #include "graph/opsproto_manager.h" +#include "inc/pass_manager.h" +#include "graph/passes/net_output_pass.h" +#include "graph/passes/data_pass.h" using std::string; using namespace std; @@ -233,6 +236,7 @@ class Impl { ModelBufferData &ge_models); graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, bool is_dynamic_input); + static graphStatus InferShapePrepare(const ComputeGraphPtr &compute_graph); void SetRtSocVersion(); void UpdateThreadContext(); void LoadOpsProto(); @@ -243,6 +247,22 @@ class Impl { OmgContext omg_context_; }; +static graphStatus InferShapePrepare(const ComputeGraphPtr &compute_graph) { + GE_CHECK_NOTNULL(compute_graph); + + PassManager prepare_infershape; + prepare_infershape.AddPass("PrepareNetoutput", new(std::nothrow) NetOutputPass); + prepare_infershape.AddPass("PrepareSubGraphReflection", new (std::nothrow) DataPass); + + auto ret = prepare_infershape.Run(compute_graph); + if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { + GELOGE(ret, "Prepair for infershape failed, ret:%d", ret); + return ret; + } + GELOGD("Prepair for infershape success!"); + return GRAPH_SUCCESS; +} + graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { GELOGD("Enter Update Data Attr Process!"); if (options_.find(kInputShape) == options_.end()) { @@ -591,7 +611,12 @@ graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { return GRAPH_PARAM_INVALID; } - auto ret = compute_graph->TopologicalSorting(); + auto ret = Impl::InferShapePrepare(root_graph); + if (ret != GRAPH_SUCCESS) { + return ret; + } + + ret = compute_graph->TopologicalSorting(); if (ret != GRAPH_SUCCESS) { GELOGE(ret, "Acl topo logical sort failed."); return ret;