|
|
|
@ -17,8 +17,6 @@
|
|
|
|
|
#include "aicore_node_executor.h"
|
|
|
|
|
#include "cce/taskdown_common.hpp"
|
|
|
|
|
#include "hybrid/executor/hybrid_execution_context.h"
|
|
|
|
|
#include "init/gelib.h"
|
|
|
|
|
#include "hybrid/executor/hybrid_execution_context.h"
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
|
namespace hybrid {
|
|
|
|
@ -28,19 +26,10 @@ AiCoreNodeTask::AiCoreNodeTask(std::vector<std::unique_ptr<AiCoreOpTask>> &&task
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status AiCoreNodeExecutor::Initialize() {
|
|
|
|
|
auto ge_lib = GELib::GetInstance();
|
|
|
|
|
GE_CHECK_NOTNULL(ge_lib);
|
|
|
|
|
if (!ge_lib->InitFlag()) {
|
|
|
|
|
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed.");
|
|
|
|
|
return GE_CLI_GE_NOT_INITIALIZED;
|
|
|
|
|
compiler_ = TaskCompilerFactory::GetInstance().GetTaskCompiler();
|
|
|
|
|
if (compiler_ != nullptr) {
|
|
|
|
|
GE_CHK_STATUS_RET(compiler_->Initialize(), "Failed to init aicore task compiler.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &kernel_manager = ge_lib->OpsKernelManagerObj();
|
|
|
|
|
auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine");
|
|
|
|
|
GE_CHECK_NOTNULL(aic_ops_store);
|
|
|
|
|
|
|
|
|
|
compiler_.reset(new(std::nothrow)AiCoreTaskCompiler(aic_ops_store));
|
|
|
|
|
GE_CHECK_NOTNULL(compiler_);
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -120,6 +109,12 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model,
|
|
|
|
|
GE_CHECK_NOTNULL(op_desc);
|
|
|
|
|
GELOGI("AiCoreNodeExecutor(%s) CompileTask Start.", node->GetName().c_str());
|
|
|
|
|
|
|
|
|
|
auto ori_node_name = node->GetName();
|
|
|
|
|
if (compiler_ == nullptr) {
|
|
|
|
|
GELOGE(FAILED, "[%s] Can not find any valid aicore task compiler.", ori_node_name.c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AiCoreNodeTaskRegistry ®istry = AiCoreNodeTaskRegistry::GetInstance();
|
|
|
|
|
std::string shape_key;
|
|
|
|
|
GE_CHK_STATUS_RET(GenNodeKey(node, shape_key), "GenNodeKey failed, op name = %s.", node->GetName().c_str());
|
|
|
|
@ -133,7 +128,6 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<domi::TaskDef> task_defs;
|
|
|
|
|
auto ori_node_name = node->GetName();
|
|
|
|
|
op_desc->SetName(ori_node_name + "_" + shape_key);
|
|
|
|
|
GE_CHK_STATUS_RET(compiler_->CompileOp(node, task_defs), "Compile op(%s) failed.", ori_node_name.c_str());
|
|
|
|
|
op_desc->SetName(ori_node_name);
|
|
|
|
@ -239,5 +233,23 @@ bool AiCoreNodeTask::IsNoOp(TaskContext &task_context) {
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TaskCompilerFactory &TaskCompilerFactory::GetInstance() {
|
|
|
|
|
static TaskCompilerFactory instance;
|
|
|
|
|
return instance;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TaskCompilerFactory::Register(CreateFn fn) {
|
|
|
|
|
compiler_func_ = fn;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<TaskCompiler> TaskCompilerFactory::GetTaskCompiler() {
|
|
|
|
|
auto compiler_instance = std::unique_ptr<TaskCompiler>(compiler_func_());
|
|
|
|
|
return compiler_instance;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CompilerFunctionRegistrar::CompilerFunctionRegistrar(CreateFn fn) {
|
|
|
|
|
TaskCompilerFactory::GetInstance().Register(fn);
|
|
|
|
|
}
|
|
|
|
|
} // namespace hybrid
|
|
|
|
|
} // namespace ge
|
|
|
|
|