modified: tests/ut/ge/hybrid/ge_hybrid_unittest.cc

pull/1204/head
zhaoxinxin 4 years ago
parent 5fe85f3f85
commit 3d0a83a455

@ -190,4 +190,34 @@ TEST_F(UtestGeHybrid, index_taskdefs_success) {
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), SUCCESS);
}
TEST_F(UtestGeHybrid, init_weight_success) {
// make graph with sub_graph
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("root_graph");
OpDescPtr op_desc = CreateOpDesc("if", IF);
/*std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));*/
//op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
//std::string kernel_name("kernel/Add");
//AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
NodePtr node = graph->AddNode(op_desc);
// make sub graph
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("if_sub_graph");
OpDescPtr const_op_desc = CreateOpDesc("const", CONSTANT);
vector<int64_t> dims_vec_0 = {2, 1, 4, 1, 2};
vector<int32_t> data_vec_0 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
GeTensorDesc tensor_desc_0(GeShape(dims_vec_0), FORMAT_NCHW, DT_INT32);
(void)TensorUtils::SetRealDimCnt(tensor_desc_0, dims_vec_0.size());
ConstGeTensorPtr constTensor_0 =
std::make_shared<GeTensor>(tensor_desc_0, (uint8_t *)&data_vec_0[0], data_vec_0.size() * sizeof(int32_t));
AttrUtils::SetTensor(const_op_desc, ge::ATTR_NAME_WEIGHTS, constTensor_0);
const_op_desc->AddOutputDesc(constTensor_0);
NodePtr const_node = sub_graph->AddNode(const_op_desc);
graph->AddSubgraph("sub", sub_graph);
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
auto ret = hybrid_model_builder.InitWeights();
}
Loading…
Cancel
Save