|
|
|
@ -19,10 +19,12 @@
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "runtime/rt.h"
|
|
|
|
|
|
|
|
|
|
#include "graph/utils/node_utils.h"
|
|
|
|
|
#define protected public
|
|
|
|
|
#define private public
|
|
|
|
|
#include "hybrid/model/hybrid_model_builder.h"
|
|
|
|
|
#include "hybrid/model/hybrid_model.h"
|
|
|
|
|
#include "hybrid/node_executor/node_executor.h"
|
|
|
|
|
#include "model/ge_model.h"
|
|
|
|
|
#include "model/ge_root_model.h"
|
|
|
|
|
#include "hybrid/node_executor/aicore/aicore_op_task.h"
|
|
|
|
@ -247,7 +249,7 @@ TEST_F(UtestGeHybrid, init_weight_success) {
|
|
|
|
|
ASSERT_EQ(ret,PARAM_INVALID);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(UtestGeHybrid, hybrid_model_executor) {
|
|
|
|
|
TEST_F(UtestGeHybrid, hybrid_model_executor) {
|
|
|
|
|
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc");
|
|
|
|
|
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
|
|
|
|
|
HybridModel model(root_model);
|
|
|
|
@ -258,3 +260,71 @@ TEST_F(UtestGeHybrid, init_weight_success) {
|
|
|
|
|
HybridModelExecutor executor(model_ptr, device_id, stream);
|
|
|
|
|
executor.Init();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(UtestGeHybrid, test_parse_parallel_group) {
|
|
|
|
|
NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl",
|
|
|
|
|
NodeExecutorManager::ExecutorType::HCCL);
|
|
|
|
|
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test");
|
|
|
|
|
OpDescPtr op_desc = CreateOpDesc("AllReduce", "AllReduce");
|
|
|
|
|
op_desc->SetId(0);
|
|
|
|
|
ge::AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, "group_1");
|
|
|
|
|
auto node = compute_graph->AddNode(op_desc);
|
|
|
|
|
std::unique_ptr<NodeItem> node_item;
|
|
|
|
|
NodeItem::Create(node, node_item);
|
|
|
|
|
node_item->node_id = 0;
|
|
|
|
|
|
|
|
|
|
op_desc->SetOpKernelLibName("ops_kernel_info_hccl");
|
|
|
|
|
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
|
|
|
|
|
HybridModel model(root_model);
|
|
|
|
|
|
|
|
|
|
HybridModelBuilder builder(model);
|
|
|
|
|
builder.root_graph_ = compute_graph;
|
|
|
|
|
ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS);
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);
|
|
|
|
|
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1);
|
|
|
|
|
|
|
|
|
|
OpDescPtr op_desc_1 = CreateOpDesc("subgraph", "PartitionedCall");
|
|
|
|
|
op_desc_1->AddSubgraphName("subgraph");
|
|
|
|
|
auto node_1 = compute_graph->AddNode(op_desc_1);
|
|
|
|
|
|
|
|
|
|
ComputeGraphPtr subgraph = MakeShared<ComputeGraph>("subgraph");
|
|
|
|
|
ASSERT_EQ(NodeUtils::SetSubgraph(*node_1, 0, subgraph), GRAPH_SUCCESS);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<NodeItem> node_item_1;
|
|
|
|
|
NodeItem::Create(node_1, node_item_1);
|
|
|
|
|
node_item_1->node_id = 1;
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS);
|
|
|
|
|
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);
|
|
|
|
|
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1);
|
|
|
|
|
|
|
|
|
|
OpDescPtr op_desc_2 = CreateOpDesc("sub_node_1", "AllReduce");
|
|
|
|
|
ge::AttrUtils::SetStr(op_desc_2, ATTR_NAME_PARALLEL_GROUP, "group_1");
|
|
|
|
|
auto node_2 = subgraph->AddNode(op_desc_2);
|
|
|
|
|
ASSERT_TRUE(node_2 != nullptr);
|
|
|
|
|
|
|
|
|
|
OpDescPtr op_desc_3 = CreateOpDesc("sub_node_2", "AllReduce2");
|
|
|
|
|
ge::AttrUtils::SetStr(op_desc_3, ATTR_NAME_PARALLEL_GROUP, "group_2");
|
|
|
|
|
auto node_3 = subgraph->AddNode(op_desc_3);
|
|
|
|
|
ASSERT_TRUE(node_3 != nullptr);
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS);
|
|
|
|
|
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 2);
|
|
|
|
|
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 2);
|
|
|
|
|
ASSERT_EQ(builder.parallel_group_to_nodes_["group_1"].size(), 2);
|
|
|
|
|
ASSERT_EQ(builder.parallel_group_to_nodes_["group_2"].size(), 1);
|
|
|
|
|
|
|
|
|
|
ASSERT_FALSE(node_item->has_observer);
|
|
|
|
|
ASSERT_TRUE(node_item_1->dependents_for_execution.empty());
|
|
|
|
|
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
|
|
|
|
|
ASSERT_TRUE(node_item->has_observer);
|
|
|
|
|
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
|
|
|
|
|
ASSERT_EQ(node_item_1->dependents_for_execution[0], node);
|
|
|
|
|
|
|
|
|
|
// repeat parse
|
|
|
|
|
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
|
|
|
|
|
ASSERT_TRUE(node_item->has_observer);
|
|
|
|
|
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
|
|
|
|
|
ASSERT_EQ(node_item_1->dependents_for_execution[0], node);
|
|
|
|
|
}
|