diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc index 0860e97584..f523387735 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc @@ -14,31 +14,37 @@ * limitations under the License. */ #include "minddata/dataset/engine/tdt/tdt_handle.h" + namespace mindspore { +extern std::set acl_handle_set; namespace dataset { -std::vector TdtHandle::acl_handle = std::vector(); - -void TdtHandle::AddHandle(acltdtChannelHandle *handle) { - if (handle != nullptr) { - acl_handle.emplace_back(handle); +void TdtHandle::AddHandle(acltdtChannelHandle **handle) { + if (*handle != nullptr) { + acl_handle_set.insert(reinterpret_cast(handle)); } } +void TdtHandle::DelHandle(acltdtChannelHandle **handle) { + void **void_handle = reinterpret_cast(handle); + acl_handle_set.erase(void_handle); +} + bool TdtHandle::DestroyHandle() { bool destroy_all = true; - for (auto &handle : acl_handle) { - if (handle != nullptr) { - if (acltdtDestroyChannel(handle) != ACL_SUCCESS) { + for (auto it = acl_handle_set.begin(); it != acl_handle_set.end(); it++) { + acltdtChannelHandle **handle = reinterpret_cast(*it); + if (*handle != nullptr) { + acltdtStopChannel(*handle); + if (acltdtDestroyChannel(*handle) != ACL_SUCCESS) { destroy_all = false; } else { - handle = nullptr; + *handle = nullptr; } } } return destroy_all; } -std::vector TdtHandle::GetHandle() { return acl_handle; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h index 5cabf8b0ec..3774e8373b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h @@ -17,23 +17,21 @@ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_ #include -#include +#include #include "acl/acl_tdt.h" namespace mindspore { namespace dataset { class TdtHandle { public: - static void AddHandle(acltdtChannelHandle *handle); + static void AddHandle(acltdtChannelHandle **handle); static bool DestroyHandle(); - static std::vector GetHandle(); + static void DelHandle(acltdtChannelHandle **handle); private: TdtHandle() {} - - static std::vector acl_handle; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc index 940a185de2..bcd8628878 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc @@ -29,15 +29,12 @@ TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) { if (acl_handle_ == nullptr) { MS_LOG(ERROR) << "Failed to create channel for tdt queue."; } - TdtHandle::AddHandle(acl_handle_); + TdtHandle::AddHandle(&acl_handle_); } TdtPlugin::~TdtPlugin() { - std::vector total_handle = TdtHandle::GetHandle(); - if (std::find(total_handle.begin(), total_handle.end(), acl_handle_) != total_handle.end()) { - if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) { - MS_LOG(ERROR) << "Failed to destroy channel for tdt queue."; - } + if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) { + MS_LOG(ERROR) << "Failed to destroy channel for tdt queue."; } } diff --git a/mindspore/ccsrc/utils/context/context_extends.cc b/mindspore/ccsrc/utils/context/context_extends.cc index 6e3edc87ba..c0279f0f8c 100644 --- a/mindspore/ccsrc/utils/context/context_extends.cc +++ b/mindspore/ccsrc/utils/context/context_extends.cc @@ -78,7 +78,7 @@ bool OpenTsd(const std::shared_ptr &ms_context_ptr) { } ms_context_ptr->increase_param(MS_CTX_TSD_REF); #ifdef ENABLE_TDTQUE - acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); + acltdtChannelHandle *acl_handle = ms_context_ptr->CreateAclTdtChannelHandle(); if (acl_handle == nullptr) { MS_LOG(EXCEPTION) << "Get acltdt handle failed"; return false; @@ -92,7 +92,7 @@ bool OpenTsd(const std::shared_ptr &ms_context_ptr) { bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force) { if (ms_context_ptr == nullptr) { - MS_LOG(EXCEPTION) << "nullptr"; + MS_LOG(EXCEPTION) << "ms_context_prt is nullptr"; } if (ms_context_ptr->get_param(MS_CTX_TSD_REF) == 0) { return true; @@ -102,22 +102,8 @@ bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force) { ms_context_ptr->set_param(MS_CTX_TSD_REF, 0); #ifdef ENABLE_TDTQUE - acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); - aclError stopStatus = acltdtStopChannel(acl_handle); - if (stopStatus != ACL_SUCCESS) { - MS_LOG(ERROR) << "Failed stop acl data channel for host queue "; - } else { - MS_LOG(INFO) << "Succeed stop acl data channel for host queue "; - } - MS_LOG(INFO) << "Succeed run cancellation callback of out-feed dequeue op "; - + ms_context_ptr->DestroyAclTdtChannelHandle(); py::gil_scoped_release gil_release; - aclError destrodStatus = acltdtDestroyChannel(acl_handle); - if (destrodStatus != ACL_SUCCESS) { - MS_LOG(ERROR) << "Failed destroy acl channel for out-feed dequeue op "; - } else { - MS_LOG(INFO) << "Succeed destroy acl channel for out-feed dequeue op "; - } try { if (ms_context_ptr->acl_tdt_print.joinable()) { MS_LOG(INFO) << "join acl tdt host receive process"; diff --git a/mindspore/core/gvar/logging_level.cc b/mindspore/core/gvar/logging_level.cc index 688e6ddcb8..5c435454a7 100644 --- a/mindspore/core/gvar/logging_level.cc +++ b/mindspore/core/gvar/logging_level.cc @@ -17,6 +17,7 @@ #include "utils/log_adapter.h" namespace mindspore { +std::set acl_handle_set = std::set(); // set default log level to WARNING for all sub modules int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING}; } // namespace mindspore diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h index f6c067e016..1429f73fb3 100644 --- a/mindspore/core/utils/log_adapter.h +++ b/mindspore/core/utils/log_adapter.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "utils/overload.h" #include "./securec.h" @@ -41,6 +42,7 @@ static constexpr size_t GetRelPathPos() noexcept { } namespace mindspore { +extern std::set acl_handle_set __attribute__((visibility("default"))); #define FILE_NAME \ (sizeof(__FILE__) > GetRelPathPos() ? static_cast(__FILE__) + GetRelPathPos() \ : static_cast(__FILE__)) diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc index b39d75b522..7435f97fd5 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -109,6 +109,43 @@ bool MsContext::set_backend_policy(const std::string &policy) { return true; } +#ifdef ENABLE_TDTQUE +namespace py = pybind11; +acltdtChannelHandle *MsContext::CreateAclTdtChannelHandle() { + uint32_t device_id = get_param(MS_CTX_DEVICE_ID); + std::string kReceivePrefix = "TF_RECEIVE_"; + std::string channel_name = "_npu_log"; + acltdtChannelHandle *acl_handle = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str()); + if (acl_handle != nullptr) { + MS_LOG(INFO) << "Success to create acltdt handle."; + acl_handle_ = acl_handle; + TdtHandle::AddHandle(&acl_handle_); + } + return acl_handle; +} + +void MsContext::DestroyAclTdtChannelHandle() { + if (acl_handle_ == nullptr) { + MS_LOG(INFO) << "The acl handle has been destroyed and the point is nullptr"; + return; + } + aclError stopStatus = acltdtStopChannel(acl_handle_); + if (stopStatus != ACL_SUCCESS) { + MS_LOG(ERROR) << "Failed stop acl data channel and the stopStatus is " << stopStatus << std::endl; + return; + } + MS_LOG(INFO) << "Succeed stop acl data channel for host queue "; + + aclError destroydStatus = acltdtDestroyChannel(acl_handle_); + if (destroydStatus != ACL_SUCCESS) { + MS_LOG(ERROR) << "Failed destroy acl channel and the destroyStatus is " << destroydStatus << std::endl; + return; + } + TdtHandle::DelHandle(&acl_handle_); + MS_LOG(INFO) << "Succeed destroy acl channel"; +} +#endif + std::string MsContext::backend_policy() const { auto res = std::find_if( policy_map_.begin(), policy_map_.end(), @@ -127,21 +164,4 @@ bool MsContext::enable_dump_ir() const { #endif } -#ifdef ENABLE_TDTQUE -acltdtChannelHandle *MsContext::get_acl_tdt_channel_handle() { - if (acl_handle == nullptr) { - std::string kReceivePrefix = "TF_RECEIVE_"; - std::string channel_name = "_npu_log"; - uint32_t device_id = get_param(MS_CTX_DEVICE_ID); - acl_handle = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str()); - if (acl_handle == nullptr) { - MS_LOG(ERROR) << "Failed to create acltdt handle : " << channel_name; - return nullptr; - } - MS_LOG(INFO) << "Success to create acltdt handle: " << channel_name; - return acl_handle; - } - return acl_handle; -} -#endif } // namespace mindspore diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h index 4f39b2ea50..119bb2e464 100644 --- a/mindspore/core/utils/ms_context.h +++ b/mindspore/core/utils/ms_context.h @@ -25,9 +25,15 @@ #include #include "utils/log_adapter.h" #include "utils/ms_utils.h" +#ifdef ENABLE_TDTQUE +#include "pybind11/pybind11.h" +#include "mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h" +using mindspore::dataset::TdtHandle; +#endif #ifndef NO_DLIB #include "acl/acl_tdt.h" #endif + namespace mindspore { enum MsBackendPolicy { kMsBackendGeOnly = 0, @@ -137,7 +143,8 @@ class MsContext { std::string backend_policy() const; bool set_backend_policy(const std::string &policy); #ifdef ENABLE_TDTQUE - acltdtChannelHandle *get_acl_tdt_channel_handle(); + acltdtChannelHandle *CreateAclTdtChannelHandle(); + void DestroyAclTdtChannelHandle(); #endif static void device_seter(DeviceSeter device) { seter_ = device; } static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } @@ -175,10 +182,9 @@ class MsContext { uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS]; float float_params_[MsCtxParam::NUM_FLOAT_PARAMS]; std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; - MsBackendPolicy backend_policy_; #ifdef ENABLE_TDTQUE - acltdtChannelHandle *acl_handle = nullptr; + acltdtChannelHandle *acl_handle_ = nullptr; #endif }; diff --git a/mindspore/core/utils/ms_utils.cc b/mindspore/core/utils/ms_utils.cc index f6a1567a7b..3905f8fbc2 100644 --- a/mindspore/core/utils/ms_utils.cc +++ b/mindspore/core/utils/ms_utils.cc @@ -14,9 +14,6 @@ * limitations under the License. */ #include "utils/ms_utils.h" -#include -#include -#include namespace mindspore { namespace common { diff --git a/mindspore/core/utils/ms_utils.h b/mindspore/core/utils/ms_utils.h index bf85f47dc9..cccfb117f2 100644 --- a/mindspore/core/utils/ms_utils.h +++ b/mindspore/core/utils/ms_utils.h @@ -19,6 +19,8 @@ #include #include #include +#include +#include #define DISABLE_COPY_AND_ASSIGN(ClassType) \ ClassType(const ClassType &) = delete; \