add force infershape for some op

pull/1224/head
wxl 4 years ago
parent 365401b52f
commit 5ae267433b

@ -272,7 +272,9 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt
GE_CHECK_NOTNULL(op_desc);
// not care result, if no this attr, stand for the op does not need force infershape
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
GELOGD("node [%s] is need do infershape , flag is %d", node_item.is_need_force_infershape);
GELOGD("node [%s] is need do infershape , flag is %d",
op_desc->GetName().c_str(),
node_item.is_need_force_infershape);
return SUCCESS;
}

@ -152,6 +152,20 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) {
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR);
}
TEST_F(UtestGeHybrid, parse_force_infershape_nodes) {
const char *const kForceInfershape = "_force_infershape_when_running";
auto graph = make_shared<ComputeGraph>("graph");
OpDescPtr op_desc = CreateOpDesc("Conv2D", "Conv2D");
ge::AttrUtils::SetBool(op_desc, kForceInfershape, true);
auto node = graph->AddNode(op_desc);
std::unique_ptr<NodeItem> new_node;
NodeItem::Create(node, new_node);
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS);
}
TEST_F(UtestGeHybrid, index_taskdefs_success) {
// build aicore task
domi::ModelTaskDef model_task_def;

Loading…
Cancel
Save