/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_COMMON_THREAD_POOL_H_ #define GE_COMMON_THREAD_POOL_H_ #include #include #include #include #include #include #include #include #include #include #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "external/ge/ge_api_error_codes.h" #include "graph/types.h" #include "common/ge/ge_util.h" namespace ge { using ThreadTask = std::function; class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool { public: explicit ThreadPool(uint32_t size = 4); ~ThreadPool(); template auto commit(Func &&func, Args &&... args) -> std::future { GELOGD("commit run task enter."); using retType = decltype(func(args...)); std::future fail_future; if (is_stoped_.load()) { GELOGE(ge::FAILED, "thread pool has been stopped."); return fail_future; } auto bindFunc = std::bind(std::forward(func), std::forward(args)...); auto task = ge::MakeShared>(bindFunc); if (task == nullptr) { GELOGE(ge::FAILED, "Make shared failed."); return fail_future; } std::future future = task->get_future(); { std::lock_guard lock{m_lock_}; tasks_.emplace([task]() { (*task)(); }); } cond_var_.notify_one(); GELOGD("commit run task end"); return future; } static void ThreadFunc(ThreadPool *thread_pool); private: std::vector pool_; std::queue tasks_; std::mutex m_lock_; std::condition_variable cond_var_; std::atomic is_stoped_; std::atomic idle_thrd_num_; }; } // namespace ge #endif // GE_COMMON_THREAD_POOL_H_