/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_SINGLE_OP_TASK_TBE_TASK_BUILDER_H_ #define GE_SINGLE_OP_TASK_TBE_TASK_BUILDER_H_ #include #include #include #include #include #include "graph/op_desc.h" #include "single_op/single_op.h" #include "single_op/single_op_model.h" namespace ge { class KernelHolder { public: KernelHolder(const char *stub_func, std::shared_ptr kernel_bin); ~KernelHolder(); void SetBinHandle(void *bin_handle) { bin_handle_ = bin_handle; } private: friend class KernelBinRegistry; const char *stub_func_; void *bin_handle_; std::shared_ptr kernel_bin_; }; class HandleHolder { public: HandleHolder(void *bin_handle); ~HandleHolder(); void SetBinHandle(void *bin_handle) { bin_handle_ = bin_handle; } void *GetBinHandle() { return bin_handle_; } private: friend class HandleRegistry; void *bin_handle_ = nullptr; }; class KernelBinRegistry { public: static KernelBinRegistry &GetInstance() { static KernelBinRegistry instance; return instance; } const char *GetUnique(const string &stub_func); const char *GetStubFunc(const std::string &stub_name); bool AddKernel(const std::string &stub_name, std::unique_ptr &&holder); private: std::map> registered_bins_; std::set unique_stubs_; std::mutex mutex_; }; class HandleRegistry { public: static HandleRegistry &GetInstance() { static HandleRegistry instance; return instance; } bool AddHandle(std::unique_ptr &&holder); private: std::set> registered_handles_; }; class TbeTaskBuilder { public: TbeTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::TaskDef &task_def); ~TbeTaskBuilder() = default; Status BuildTask(TbeOpTask &task, const SingleOpModelParam ¶m); private: Status InitTilingInfo(TbeOpTask &task); Status SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m, const OpDescPtr &op_desc); Status SetKernelWithHandleArgs(TbeOpTask &task, const SingleOpModelParam ¶m, const OpDescPtr &op_desc); Status GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m) const; Status RegisterKernel(TbeOpTask &task, const SingleOpModelParam ¶m); Status RegisterKernelWithHandle(TbeOpTask &task, const SingleOpModelParam ¶m); Status DoRegisterKernel(const OpKernelBin &kernel_bin, const char *bin_file_key, void **bin_handle, const SingleOpModelParam ¶m); Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam ¶m) const; Status DoRegisterMeta(void *bin_handle); static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name); const NodePtr node_; const OpDescPtr op_desc_; const domi::TaskDef &task_def_; const domi::KernelDef &kernel_def_; const domi::KernelDefWithHandle &kernel_def_with_handle_; const std::string stub_name_; void *handle_ = nullptr; }; } // namespace ge #endif // GE_SINGLE_OP_TASK_TBE_TASK_BUILDER_H_