diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 3d00ff7f..78a69392 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -601,6 +601,47 @@ graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *pat return GRAPH_SUCCESS; } +graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { + auto compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + auto root_graph = compute_graph->GetParentGraph(); + if (root_graph != nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "Input param should not be subgraph"); + return GRAPH_PARAM_INVALID; + } + + auto ret = Impl::InferShapePrepare(compute_graph); + if (ret != GRAPH_SUCCESS) { + return ret; + } + + ret = compute_graph->TopologicalSorting(); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Acl topo logical sort failed."); + return ret; + } + + ret = compute_graph->InferOriginFormat(); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Acl InferOriginFormat failed."); + return ret; + } + + for (auto &node: compute_graph->GetAllNodes()) { + graphStatus ret = ShapeRefiner::InferShapeAndType(node); + if (ret == GRAPH_PARAM_INVALID) { + GELOGW("Can not find infershape func."); + continue; + } else if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Acl infershape failed."); + return ret; + } + } + + return GRAPH_SUCCESS; +} + graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len) { GE_CHECK_NOTNULL(file); diff --git a/inc/external/ge/ge_ir_build.h b/inc/external/ge/ge_ir_build.h index afaf42ac..182c0444 100644 --- a/inc/external/ge/ge_ir_build.h +++ b/inc/external/ge/ge_ir_build.h @@ -100,6 +100,16 @@ graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &mod */ graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version); +/** + * @ingroup AscendCL + * @brief infer shape and data type + * + * @param graph[IN] the graph ready to build + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +graphStatus aclgrphInferShapeAndType(ge::Graph &graph); + /** * @ingroup AscendCL * @brief dump graph