From 7ce4fc36652a46c325fbc24803ed8172d8260745 Mon Sep 17 00:00:00 2001
From: wuweikang
Date: Thu, 4 Mar 2021 09:23:10 +0800
Subject: [PATCH] multi-kernel modification

Handle RT_MODEL_TASK_ALL_KERNEL (kernel-with-handle) task defs in the same
places that already handle RT_MODEL_TASK_KERNEL:

- HybridModelBuilder::IndexTaskDefs reads op_index from
  kernel_with_handle().context() and loads the TBE kernel binary into the
  op desc for such tasks; it also logs the resolved op_index and task type.
- NeedHybridModel counts ALL_KERNEL tasks as kernel tasks.
- SingleOpModel::BuildTaskList takes the kernel context from
  kernel_with_handle() when the task is an ALL_KERNEL task.
- Add unit tests that run IndexTaskDefs on a model with one ALL_KERNEL task.

---
NOTE(review): the copy of this patch that was reviewed had its line breaks
collapsed and every angle-bracketed span removed (template arguments and the
author's e-mail address). Line breaks, indentation, blank context lines and
template arguments below are reconstructed -- verify them against the tree
before applying. The author's address could not be recovered. The "index"
lines still carry the blob ids of the submitted patch.

Changed relative to the submitted patch (line counts are unchanged):
- IndexTaskDefs: the new GELOGD casts task_type to int for "%d", matching the
  existing "Skip task type" log a few lines above it.
- ge_hybrid_unittest.cc: the "// build aicore task" comments in the two new
  tests are replaced by comments that say what each test sets up.

 ge/hybrid/model/hybrid_model_builder.cc  |  7 ++-
 ge/single_op/single_op_model.cc          |  9 +--
 tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 80 ++++++++++++++++++++++++
 3 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc
index 7ea9e446..48558e83 100755
--- a/ge/hybrid/model/hybrid_model_builder.cc
+++ b/ge/hybrid/model/hybrid_model_builder.cc
@@ -1131,19 +1131,22 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const
       op_index = task_def.kernel_ex().op_index();
     } else if (task_type == RT_MODEL_TASK_HCCL) {
       op_index = task_def.kernel_hccl().op_index();
+    } else if (task_type == RT_MODEL_TASK_ALL_KERNEL) {
+      op_index = task_def.kernel_with_handle().context().op_index();
     } else {
       GELOGD("Skip task type: %d", static_cast<int>(task_type));
       continue;
     }
+    GELOGD("op_index = %u, task_type = %d", op_index, static_cast<int>(task_type));
 
     auto iter = node_map.find(op_index);
     if (iter == node_map.end()) {
-      GELOGE(INTERNAL_ERROR, "Failed to get node by index = %u", op_index);
+      GELOGE(INTERNAL_ERROR, "Failed to get node by op_index = %u", op_index);
       return INTERNAL_ERROR;
     }
 
     auto &node = iter->second;
-    if (task_type == RT_MODEL_TASK_KERNEL) {
+    if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
       ge_model->GetTBEKernelStore().LoadTBEKernelBinToOpDesc(node->GetOpDesc());
     }
 
diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc
index 43c47894..49dde9c4 100755
--- a/ge/single_op/single_op_model.cc
+++ b/ge/single_op/single_op_model.cc
@@ -48,7 +48,8 @@ bool NeedHybridModel(GeModelPtr &ge_model) {
   auto tasks = ge_model->GetModelTaskDefPtr()->task();
   int32_t kernel_task_num = 0;
   for (int i = 0; i < tasks.size(); ++i) {
-    if (static_cast<rtModelTaskType_t>(tasks[i].type()) == RT_MODEL_TASK_KERNEL) {
+    auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type());
+    if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
       kernel_task_num++;
       if (kernel_task_num > 1) {
         return true;
@@ -254,9 +255,9 @@ Status SingleOpModel::BuildTaskList(StreamResource *stream_resource, SingleOp &s
     GELOGI("[%s] Task[%d], type = %u, DebugString = %s", model_name_.c_str(), i, task_def.type(),
            task_def.DebugString().c_str());
     auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
-    if (task_type == RT_MODEL_TASK_KERNEL) {
-      const domi::KernelDef &kernel_def = task_def.kernel();
-      const auto &context = kernel_def.context();
+    if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
+      const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() :
+                                                                task_def.kernel_with_handle().context();
       auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
       if (kernel_type == ccKernelType::TE) {
         GELOGD("Building TBE task");
diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
index 97a36894..0b6ca271 100644
--- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
+++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
@@ -41,6 +41,7 @@
 using namespace std;
 using namespace testing;
 using namespace ge;
+using namespace hybrid;
 
 class UtestGeHybrid : public testing::Test {
  protected:
@@ -110,4 +111,83 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) {
   auto node = graph->AddNode(op_desc);
   optiling::OpRunInfo tiling_info;
   ASSERT_EQ(aicore_task->CalcTilingInfo(node, tiling_info), SUCCESS);
+}
+
+TEST_F(UtestGeHybrid, index_taskdefs_failed) {
+  // one ALL_KERNEL task with op_index 1; no node is added to the graph, so indexing is expected to fail
+  domi::ModelTaskDef model_task_def;
+
+  std::shared_ptr<domi::ModelTaskDef> model_task_def_ptr = make_shared<domi::ModelTaskDef>(model_task_def);
+  domi::TaskDef *task_def = model_task_def_ptr->add_task();
+  GeModelPtr ge_model = make_shared<GeModel>();
+  ge_model->SetModelTaskDef(model_task_def_ptr);
+
+  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
+  task_def->set_type(RT_MODEL_TASK_ALL_KERNEL);
+  domi::KernelDefWithHandle *kernel_with_handle = task_def->mutable_kernel_with_handle();
+  kernel_with_handle->set_original_kernel_key("");
+  kernel_with_handle->set_node_info("");
+  kernel_with_handle->set_block_dim(32);
+  kernel_with_handle->set_args_size(64);
+  string args(64, '1');
+  kernel_with_handle->set_args(args.data(), 64);
+  domi::KernelContext *context = kernel_with_handle->mutable_context();
+  context->set_op_index(1);
+  context->set_kernel_type(2);  // ccKernelType::TE
+  uint16_t args_offset[9] = {0};
+  context->set_args_offset(args_offset, 9 * sizeof(uint16_t));
+
+  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
+  std::vector<char> kernelBin;
+  TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
+  op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
+  std::string kernel_name("kernel/Add");
+  AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
+
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
+  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
+  HybridModel hybrid_model(ge_root_model);
+  HybridModelBuilder hybrid_model_builder(hybrid_model);
+
+  ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR);
+}
+
+TEST_F(UtestGeHybrid, index_taskdefs_success) {
+  // one ALL_KERNEL task with op_index 0; expected to resolve to the single node added to the graph below
+  domi::ModelTaskDef model_task_def;
+
+  std::shared_ptr<domi::ModelTaskDef> model_task_def_ptr = make_shared<domi::ModelTaskDef>(model_task_def);
+  domi::TaskDef *task_def = model_task_def_ptr->add_task();
+  GeModelPtr ge_model = make_shared<GeModel>();
+  ge_model->SetModelTaskDef(model_task_def_ptr);
+
+  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
+  task_def->set_type(RT_MODEL_TASK_ALL_KERNEL);
+  domi::KernelDefWithHandle *kernel_with_handle = task_def->mutable_kernel_with_handle();
+  kernel_with_handle->set_original_kernel_key("");
+  kernel_with_handle->set_node_info("");
+  kernel_with_handle->set_block_dim(32);
+  kernel_with_handle->set_args_size(64);
+  string args(64, '1');
+  kernel_with_handle->set_args(args.data(), 64);
+  domi::KernelContext *context = kernel_with_handle->mutable_context();
+  context->set_op_index(0);
+  context->set_kernel_type(2);  // ccKernelType::TE
+  uint16_t args_offset[9] = {0};
+  context->set_args_offset(args_offset, 9 * sizeof(uint16_t));
+
+  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
+  std::vector<char> kernelBin;
+  TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
+  op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
+  std::string kernel_name("kernel/Add");
+  AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
+
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
+  NodePtr node = graph->AddNode(op_desc);
+  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
+  HybridModel hybrid_model(ge_root_model);
+  HybridModelBuilder hybrid_model_builder(hybrid_model);
+
+  ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), SUCCESS);
 }
\ No newline at end of file