From 5ae267433be2f99134d5fe26f6b6adbcb37f71ba Mon Sep 17 00:00:00 2001 From: wxl Date: Tue, 9 Mar 2021 22:36:32 +0800 Subject: [PATCH] add force infershape for some op --- ge/hybrid/model/hybrid_model_builder.cc | 4 +++- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 58a7c23f..a349210d 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -272,7 +272,9 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt GE_CHECK_NOTNULL(op_desc); // not care result, if no this attr, stand for the op does not need force infershape (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); - GELOGD("node [%s] is need do infershape , flag is %d", node_item.is_need_force_infershape); + GELOGD("node [%s] is need do infershape , flag is %d", + op_desc->GetName().c_str(), + node_item.is_need_force_infershape); return SUCCESS; } diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 0b6ca271..286186de 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -152,6 +152,20 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); } +TEST_F(UtestGeHybrid, parse_force_infershape_nodes) { + const char *const kForceInfershape = "_force_infershape_when_running"; + auto graph = make_shared("graph"); + OpDescPtr op_desc = CreateOpDesc("Conv2D", "Conv2D"); + ge::AttrUtils::SetBool(op_desc, kForceInfershape, true); + auto node = graph->AddNode(op_desc); + std::unique_ptr new_node; + NodeItem::Create(node, new_node); + GeRootModelPtr ge_root_model = make_shared(graph); + HybridModel hybrid_model(ge_root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS); +} + TEST_F(UtestGeHybrid, index_taskdefs_success) { // build aicore task domi::ModelTaskDef model_task_def;