From b3bd33546fd43613d46f31e9336cf4180800f662 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B6=9B?= Date: Thu, 7 Jan 2021 09:30:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!886=20:?= =?UTF-8?q?=20remove=20interface=20aclgrphInfershapeAndType'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ge/ir_build/ge_ir_build.cc | 41 +++++++++++++++++++++++++++++++++++ inc/external/ge/ge_ir_build.h | 10 +++++++++ 2 files changed, 51 insertions(+) diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 3d00ff7f..78a69392 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -601,6 +601,47 @@ graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *pat return GRAPH_SUCCESS; } +graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { + auto compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + auto root_graph = compute_graph->GetParentGraph(); + if (root_graph != nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "Input param should not be subgraph"); + return GRAPH_PARAM_INVALID; + } + + auto ret = Impl::InferShapePrepare(compute_graph); + if (ret != GRAPH_SUCCESS) { + return ret; + } + + ret = compute_graph->TopologicalSorting(); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Acl topo logical sort failed."); + return ret; + } + + ret = compute_graph->InferOriginFormat(); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Acl InferOriginFormat failed."); + return ret; + } + + for (auto &node: compute_graph->GetAllNodes()) { + graphStatus ret = ShapeRefiner::InferShapeAndType(node); + if (ret == GRAPH_PARAM_INVALID) { + GELOGW("Can not find infershape func."); + continue; + } else if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Acl infershape failed."); + return ret; + } + } + + return GRAPH_SUCCESS; +} + graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len) { GE_CHECK_NOTNULL(file); diff --git a/inc/external/ge/ge_ir_build.h b/inc/external/ge/ge_ir_build.h index afaf42ac..182c0444 100644 --- a/inc/external/ge/ge_ir_build.h +++ b/inc/external/ge/ge_ir_build.h @@ -100,6 +100,16 @@ graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &mod */ graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version); +/** + * @ingroup AscendCL + * @brief infer shape and data type + * + * @param graph[IN] the graph ready to build + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +graphStatus aclgrphInferShapeAndType(ge::Graph &graph); + /** * @ingroup AscendCL * @brief dump graph