/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_HYBRID_NODE_EXECUTOR_NODE_EXECUTOR_H_ #define GE_HYBRID_NODE_EXECUTOR_NODE_EXECUTOR_H_ #include "external/ge/ge_api_error_codes.h" #include "common/opskernel/ops_kernel_builder.h" #include "graph/node.h" #include "task_context.h" namespace ge { const uint32_t MEMORY_ALIGN_RATIO = 2; const uint32_t MEMORY_ALIGN_SIZE = 32; namespace hybrid { class HybridModel; // Base class of Node Task class NodeTask { public: NodeTask() = default; virtual ~NodeTask() = default; /** * Update tiling data * @param context instance of TaskContext * @return SUCCESS on success, error code otherwise */ virtual Status UpdateTilingData(TaskContext &context) { return SUCCESS; } /** * Init * @param context instance of TaskContext * @return SUCCESS on success, error code otherwise */ virtual Status Init(TaskContext &context) { return SUCCESS; } /** * Whether this task supports dynamic shape * @return true if this task supports dynamic shape, false otherwise */ virtual bool IsSupportDynamicShape() { return true; } /** * Update args for execution * @param context instance of TaskContext * @return SUCCESS on success, error code otherwise */ virtual Status UpdateArgs(TaskContext &context) = 0; /** * Execute task async * @param context instance of TaskContext * @param done_callback callback function, will be invoked after task is done * @return SUCCESS on success, error code otherwise */ virtual Status ExecuteAsync(TaskContext &context, std::function done_callback) = 0; }; // Node executor class NodeExecutor { public: NodeExecutor() = default; virtual ~NodeExecutor() = default; /** * Initialize node executor * @return SUCCESS on success, error code otherwise */ virtual Status Initialize() { return SUCCESS; } /** * Finalize node executor * @return SUCCESS on success, error code otherwise */ virtual Status Finalize() { return SUCCESS; } /** * Load task in load stage * @param model instance of HybridModel * @param node node * @param task generated node task * @return SUCCESS on success, error code otherwise */ virtual Status LoadTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const; /** * Compile task in run stage * @param model instance of HybridModel * @param node node * @param task generated node task * @return SUCCESS on success, error code otherwise */ virtual Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const; /** * Preparation actions before execution * @param task instance of NodeTask * @param context instance of TaskContext * @return SUCCESS on success, error code otherwise */ virtual Status PrepareTask(NodeTask &task, TaskContext &context) const; /** * Execute task * @param task instance of NodeTask * @param context instance of TaskContext * @param callback callback function which will be invoked after computation is done * @return SUCCESS on success, error code otherwise */ virtual Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function &callback) const; }; class NodeExecutorManager { public: enum class ExecutorType { AICORE, AICPU_TF, AICPU_CUSTOM, COMPILED_SUBGRAPH, DYNAMIC_SUBGRAPH, GE_LOCAL, CONTROL_OP, HCCL, RTS, HOST_CPU, RESERVED }; static NodeExecutorManager &GetInstance() { static NodeExecutorManager instance; return instance; } /** * Register build of executor * @param executor_type type of executor * @param builder build function */ void RegisterExecutorBuilder(ExecutorType executor_type, const std::function &builder); /** * Initialize executor if needed * @return SUCCESS on success, error code otherwise */ Status EnsureInitialized(); Status InitializeExecutors(); void FinalizeExecutors(); /** * CalcOpRunningParam * @param node node * @return SUCCESS on success, error code otherwise */ Status CalcOpRunningParam(Node &node) const; /** * Get executor by node * @param node node * @param executor executor * @return SUCCESS on success, error code otherwise */ Status GetExecutor(Node &node, const NodeExecutor **executor) const; /** * Resolve executor type by node * @param node node * @return executor type */ ExecutorType ResolveExecutorType(Node &node) const; private: std::map> executors_; std::map> builders_; std::map engine_mapping_; std::mutex mu_; bool initialized_ = false; bool executor_initialized_ = false; int ref_count_ = 0; }; class NodeExecutorRegistrar { public: NodeExecutorRegistrar(NodeExecutorManager::ExecutorType executor_type, NodeExecutor *(*builder)()); ~NodeExecutorRegistrar() = default; }; } // namespace hybrid } // namespace ge #define REGISTER_NODE_EXECUTOR_BUILDER(engine_type, executor) \ REGISTER_NODE_EXECUTOR_BUILDER_UNIQ_HELPER(__COUNTER__, engine_type, executor) #define REGISTER_NODE_EXECUTOR_BUILDER_UNIQ_HELPER(ctr, engine_type, executor) \ REGISTER_NODE_EXECUTOR_BUILDER_UNIQ(ctr, engine_type, executor) #define REGISTER_NODE_EXECUTOR_BUILDER_UNIQ(ctr, engine_type, executor) \ static ::ge::hybrid::NodeExecutorRegistrar register_##ctr \ __attribute__((unused)) = \ ::ge::hybrid::NodeExecutorRegistrar(engine_type, []()->::ge::hybrid::NodeExecutor* { \ return new (std::nothrow) executor(); \ }) #endif // GE_HYBRID_NODE_EXECUTOR_NODE_EXECUTOR_H_