diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc index 4c32f131..dbd784c6 100755 --- a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc +++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -17,8 +17,6 @@ #include "aicore_node_executor.h" #include "cce/taskdown_common.hpp" #include "hybrid/executor/hybrid_execution_context.h" -#include "init/gelib.h" -#include "hybrid/executor/hybrid_execution_context.h" namespace ge { namespace hybrid { @@ -28,19 +26,10 @@ AiCoreNodeTask::AiCoreNodeTask(std::vector> &&task } Status AiCoreNodeExecutor::Initialize() { - auto ge_lib = GELib::GetInstance(); - GE_CHECK_NOTNULL(ge_lib); - if (!ge_lib->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); - return GE_CLI_GE_NOT_INITIALIZED; + compiler_ = TaskCompilerFactory::GetInstance().GetTaskCompiler(); + if (compiler_ != nullptr) { + GE_CHK_STATUS_RET(compiler_->Initialize(), "Failed to init aicore task compiler."); } - - auto &kernel_manager = ge_lib->OpsKernelManagerObj(); - auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); - GE_CHECK_NOTNULL(aic_ops_store); - - compiler_.reset(new(std::nothrow)AiCoreTaskCompiler(aic_ops_store)); - GE_CHECK_NOTNULL(compiler_); return SUCCESS; } @@ -120,6 +109,12 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, GE_CHECK_NOTNULL(op_desc); GELOGI("AiCoreNodeExecutor(%s) CompileTask Start.", node->GetName().c_str()); + auto ori_node_name = node->GetName(); + if (compiler_ == nullptr) { + GELOGE(FAILED, "[%s] Can not find any valid aicore task compiler.", ori_node_name.c_str()); + return FAILED; + } + AiCoreNodeTaskRegistry ®istry = AiCoreNodeTaskRegistry::GetInstance(); std::string shape_key; GE_CHK_STATUS_RET(GenNodeKey(node, shape_key), "GenNodeKey failed, op name = %s.", node->GetName().c_str()); @@ -133,7 +128,6 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, } std::vector task_defs; - auto ori_node_name = node->GetName(); op_desc->SetName(ori_node_name + "_" + shape_key); GE_CHK_STATUS_RET(compiler_->CompileOp(node, task_defs), "Compile op(%s) failed.", ori_node_name.c_str()); op_desc->SetName(ori_node_name); @@ -239,5 +233,23 @@ bool AiCoreNodeTask::IsNoOp(TaskContext &task_context) { return true; } + +TaskCompilerFactory &TaskCompilerFactory::GetInstance() { + static TaskCompilerFactory instance; + return instance; +} + +void TaskCompilerFactory::Register(CreateFn fn) { + compiler_func_ = fn; +} + +std::unique_ptr TaskCompilerFactory::GetTaskCompiler() { + auto compiler_instance = std::unique_ptr(compiler_func_()); + return compiler_instance; +} + +CompilerFunctionRegistrar::CompilerFunctionRegistrar(CreateFn fn) { + TaskCompilerFactory::GetInstance().Register(fn); +} } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.h b/ge/hybrid/node_executor/aicore/aicore_node_executor.h index 374782dc..989090e9 100755 --- a/ge/hybrid/node_executor/aicore/aicore_node_executor.h +++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.h @@ -18,13 +18,21 @@ #define GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ #include "hybrid/node_executor/aicore/aicore_task_builder.h" -#include "hybrid/node_executor/aicore/aicore_task_compiler.h" #include "hybrid/node_executor/node_executor.h" #include #include namespace ge { namespace hybrid { + +class TaskCompiler { + public: + TaskCompiler() = default; + virtual ~TaskCompiler() = default; + virtual Status CompileOp(const NodePtr &node, std::vector &tasks) = 0; + virtual Status Initialize() = 0; +}; + class AiCoreNodeTaskRegistry { public: ~AiCoreNodeTaskRegistry() = default; @@ -65,8 +73,33 @@ class AiCoreNodeExecutor : public NodeExecutor { private: static Status GenNodeKey(const NodePtr &node, std::string &node_key); - std::unique_ptr compiler_; + std::unique_ptr compiler_; +}; + +using CreateFn = TaskCompiler *(*)(); +class TaskCompilerFactory { + public: + static TaskCompilerFactory &GetInstance(); + void Register(CreateFn fn); + std::unique_ptr GetTaskCompiler(); + + private: + CreateFn compiler_func_; +}; + +class CompilerFunctionRegistrar { + public: + CompilerFunctionRegistrar(CreateFn fn); + ~CompilerFunctionRegistrar() = default; }; } // namespace hybrid } // namespace ge -#endif //GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ + +#define REGISTER_TASK_COMPILER(compiler) \ + static ::ge::hybrid::CompilerFunctionRegistrar register_compiler_function \ + __attribute__((unused)) = \ + ::ge::hybrid::CompilerFunctionRegistrar([]()->::ge::hybrid::TaskCompiler* { \ + return new (std::nothrow) compiler(); \ + }) \ + +#endif //GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ diff --git a/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc b/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc index ed92ada7..26a41737 100755 --- a/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc +++ b/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc @@ -18,6 +18,7 @@ #include "framework/common/debug/log.h" #include "graph/debug/ge_attr_define.h" #include "opskernel_manager/ops_kernel_builder_manager.h" +#include "init/gelib.h" namespace ge { namespace hybrid { @@ -25,11 +26,22 @@ namespace { uintptr_t kWeightBase = 0x10000000; uintptr_t kMemBase = 0x20000000; uint64_t kFakeSize = 0x10000000UL; +REGISTER_TASK_COMPILER(AiCoreTaskCompiler); } std::mutex AiCoreTaskCompiler::mu_; -AiCoreTaskCompiler::AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store) - : aic_kernel_store_(std::move(aic_kernel_store)) {} +Status AiCoreTaskCompiler::Initialize() { + auto ge_lib = GELib::GetInstance(); + GE_CHECK_NOTNULL(ge_lib); + if (!ge_lib->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); + return GE_CLI_GE_NOT_INITIALIZED; + } + auto &kernel_manager = ge_lib->OpsKernelManagerObj(); + aic_kernel_store_ = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); + GE_CHECK_NOTNULL(aic_kernel_store_); + return SUCCESS; +} Status AiCoreTaskCompiler::DoCompileOp(const NodePtr &node) const { GE_CHECK_NOTNULL(node); diff --git a/ge/hybrid/node_executor/aicore/aicore_task_compiler.h b/ge/hybrid/node_executor/aicore/aicore_task_compiler.h index 38ed458f..bf948349 100755 --- a/ge/hybrid/node_executor/aicore/aicore_task_compiler.h +++ b/ge/hybrid/node_executor/aicore/aicore_task_compiler.h @@ -19,15 +19,17 @@ #include #include "opskernel_manager/ops_kernel_manager.h" +#include "aicore_node_executor.h" namespace ge { namespace hybrid { -class AiCoreTaskCompiler { +class AiCoreTaskCompiler : public TaskCompiler { public: - explicit AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store); + AiCoreTaskCompiler() = default; ~AiCoreTaskCompiler() = default; - Status CompileOp(const NodePtr &node, std::vector &tasks); + Status CompileOp(const NodePtr &node, std::vector &tasks) override; + Status Initialize() override; private: Status DoCompileOp(const NodePtr &node) const; Status DoGenerateTask(const Node &node, std::vector &tasks);