modified: ../../tests/ut/ge/hybrid/ge_hybrid_unittest.cc

pull/1204/head
zhaoxinxin 4 years ago
parent 801a1e0fca
commit 56ff720fac

@ -15,8 +15,8 @@
*/
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
#include "runtime/rt.h"
#define protected public
@ -25,7 +25,6 @@
#include "hybrid/model/hybrid_model.h"
#include "model/ge_model.h"
#include "model/ge_root_model.h"
#include "hybrid/node_executor/aicore/aicore_op_task.h"
#include "framework/common/taskdown_common.h"
#include "framework/common/debug/log.h"
@ -33,6 +32,8 @@
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/node_executor/aicore/aicore_task_builder.h"
#include "graph/load/model_manager/tbe_handle_store.h"
#include "graph/manager/graph_mem_allocator.h"
#include "hybrid/common/npu_memory_allocator.h"
#include "graph/types.h"
#include "graph/utils/tensor_utils.h"
@ -44,6 +45,7 @@ using namespace testing;
using namespace ge;
using namespace hybrid;
class UtestGeHybrid : public testing::Test {
protected:
void SetUp() {}
@ -194,14 +196,10 @@ TEST_F(UtestGeHybrid, index_taskdefs_success) {
}
TEST_F(UtestGeHybrid, init_weight_success) {
NpuMemoryAllocator::allocators_.emplace(make_pair(0, nullptr));
// make graph with sub_graph
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("root_graph");
OpDescPtr op_desc = CreateOpDesc("if", IF);
/*std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));*/
//op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
//std::string kernel_name("kernel/Add");
//AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
NodePtr node = graph->AddNode(op_desc);
// make sub graph
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("if_sub_graph");
@ -218,9 +216,16 @@ TEST_F(UtestGeHybrid, init_weight_success) {
graph->AddSubgraph("sub", sub_graph);
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
GeModelPtr ge_sub_model = make_shared<GeModelPtr>(sub_graph);
GeModelPtr ge_sub_model = make_shared<GeModel>();
//Buffer weight_buffer = Buffer(128,0);
//ge_sub_model->SetWeight(weight_buffer);
ge_root_model->SetSubgraphInstanceNameToModel("sub",ge_sub_model);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
auto ret = hybrid_model_builder.InitWeights();
ASSERT_EQ(ret,SUCCESS);
Buffer weight_buffer = Buffer(128,0);
ge_sub_model->SetWeight(weight_buffer);
ret = hybrid_model_builder.InitWeights();
ASSERT_EQ(ret,PARAM_INVALID);
}
Loading…
Cancel
Save