!1189 multi-kernel modification

From: @HW_KK
Reviewed-by: @xchu42,@ji_chen
Signed-off-by:
pull/1189/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 70dd14f975

@ -1131,19 +1131,22 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const
op_index = task_def.kernel_ex().op_index();
} else if (task_type == RT_MODEL_TASK_HCCL) {
op_index = task_def.kernel_hccl().op_index();
} else if (task_type == RT_MODEL_TASK_ALL_KERNEL) {
op_index = task_def.kernel_with_handle().context().op_index();
} else {
GELOGD("Skip task type: %d", static_cast<int>(task_type));
continue;
}
GELOGD("op_index = %u, task_type = %d", op_index, task_type);
auto iter = node_map.find(op_index);
if (iter == node_map.end()) {
GELOGE(INTERNAL_ERROR, "Failed to get node by index = %u", op_index);
GELOGE(INTERNAL_ERROR, "Failed to get node by op_index = %u", op_index);
return INTERNAL_ERROR;
}
auto &node = iter->second;
if (task_type == RT_MODEL_TASK_KERNEL) {
if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
ge_model->GetTBEKernelStore().LoadTBEKernelBinToOpDesc(node->GetOpDesc());
}

@ -48,7 +48,8 @@ bool NeedHybridModel(GeModelPtr &ge_model) {
auto tasks = ge_model->GetModelTaskDefPtr()->task();
int32_t kernel_task_num = 0;
for (int i = 0; i < tasks.size(); ++i) {
if (static_cast<rtModelTaskType_t>(tasks[i].type()) == RT_MODEL_TASK_KERNEL) {
auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type());
if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
kernel_task_num++;
if (kernel_task_num > 1) {
return true;
@ -254,9 +255,9 @@ Status SingleOpModel::BuildTaskList(StreamResource *stream_resource, SingleOp &s
GELOGI("[%s] Task[%d], type = %u, DebugString = %s", model_name_.c_str(), i, task_def.type(),
task_def.DebugString().c_str());
auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
if (task_type == RT_MODEL_TASK_KERNEL) {
const domi::KernelDef &kernel_def = task_def.kernel();
const auto &context = kernel_def.context();
if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() :
task_def.kernel_with_handle().context();
auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
if (kernel_type == ccKernelType::TE) {
GELOGD("Building TBE task");

@ -41,6 +41,7 @@
using namespace std;
using namespace testing;
using namespace ge;
using namespace hybrid;
class UtestGeHybrid : public testing::Test {
protected:
@ -110,4 +111,83 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) {
auto node = graph->AddNode(op_desc);
optiling::OpRunInfo tiling_info;
ASSERT_EQ(aicore_task->CalcTilingInfo(node, tiling_info), SUCCESS);
}
TEST_F(UtestGeHybrid, index_taskdefs_failed) {
// build aicore task
domi::ModelTaskDef model_task_def;
std::shared_ptr<domi::ModelTaskDef> model_task_def_ptr = make_shared<domi::ModelTaskDef>(model_task_def);
domi::TaskDef *task_def = model_task_def_ptr->add_task();
GeModelPtr ge_model = make_shared<GeModel>();
ge_model->SetModelTaskDef(model_task_def_ptr);
auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
task_def->set_type(RT_MODEL_TASK_ALL_KERNEL);
domi::KernelDefWithHandle *kernel_with_handle = task_def->mutable_kernel_with_handle();
kernel_with_handle->set_original_kernel_key("");
kernel_with_handle->set_node_info("");
kernel_with_handle->set_block_dim(32);
kernel_with_handle->set_args_size(64);
string args(64, '1');
kernel_with_handle->set_args(args.data(), 64);
domi::KernelContext *context = kernel_with_handle->mutable_context();
context->set_op_index(1);
context->set_kernel_type(2); // ccKernelType::TE
uint16_t args_offset[9] = {0};
context->set_args_offset(args_offset, 9 * sizeof(uint16_t));
OpDescPtr op_desc = CreateOpDesc("Add", "Add");
std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
std::string kernel_name("kernel/Add");
AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR);
}
TEST_F(UtestGeHybrid, index_taskdefs_success) {
// build aicore task
domi::ModelTaskDef model_task_def;
std::shared_ptr<domi::ModelTaskDef> model_task_def_ptr = make_shared<domi::ModelTaskDef>(model_task_def);
domi::TaskDef *task_def = model_task_def_ptr->add_task();
GeModelPtr ge_model = make_shared<GeModel>();
ge_model->SetModelTaskDef(model_task_def_ptr);
auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
task_def->set_type(RT_MODEL_TASK_ALL_KERNEL);
domi::KernelDefWithHandle *kernel_with_handle = task_def->mutable_kernel_with_handle();
kernel_with_handle->set_original_kernel_key("");
kernel_with_handle->set_node_info("");
kernel_with_handle->set_block_dim(32);
kernel_with_handle->set_args_size(64);
string args(64, '1');
kernel_with_handle->set_args(args.data(), 64);
domi::KernelContext *context = kernel_with_handle->mutable_context();
context->set_op_index(0);
context->set_kernel_type(2); // ccKernelType::TE
uint16_t args_offset[9] = {0};
context->set_args_offset(args_offset, 9 * sizeof(uint16_t));
OpDescPtr op_desc = CreateOpDesc("Add", "Add");
std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
std::string kernel_name("kernel/Add");
AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
NodePtr node = graph->AddNode(op_desc);
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), SUCCESS);
}
Loading…
Cancel
Save