!14133 tensorprint_debug

From: @yepei6
Reviewed-by: 
Signed-off-by:
pull/14133/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0ef2d78411

@ -14,31 +14,37 @@
* limitations under the License. * limitations under the License.
*/ */
#include "minddata/dataset/engine/tdt/tdt_handle.h" #include "minddata/dataset/engine/tdt/tdt_handle.h"
namespace mindspore { namespace mindspore {
extern std::set<void **> acl_handle_set;
namespace dataset { namespace dataset {
std::vector<acltdtChannelHandle *> TdtHandle::acl_handle = std::vector<acltdtChannelHandle *>(); void TdtHandle::AddHandle(acltdtChannelHandle **handle) {
if (*handle != nullptr) {
void TdtHandle::AddHandle(acltdtChannelHandle *handle) { acl_handle_set.insert(reinterpret_cast<void **>(handle));
if (handle != nullptr) {
acl_handle.emplace_back(handle);
} }
} }
// Removes a previously registered channel-handle slot from the global
// acl_handle_set. The set is keyed by the slot's address (stored as void **),
// so the handle value itself is never dereferenced here. Erasing an address
// that is not in the set is a no-op.
void TdtHandle::DelHandle(acltdtChannelHandle **handle) {
  acl_handle_set.erase(reinterpret_cast<void **>(handle));
}
bool TdtHandle::DestroyHandle() { bool TdtHandle::DestroyHandle() {
bool destroy_all = true; bool destroy_all = true;
for (auto &handle : acl_handle) { for (auto it = acl_handle_set.begin(); it != acl_handle_set.end(); it++) {
if (handle != nullptr) { acltdtChannelHandle **handle = reinterpret_cast<acltdtChannelHandle **>(*it);
if (acltdtDestroyChannel(handle) != ACL_SUCCESS) { if (*handle != nullptr) {
acltdtStopChannel(*handle);
if (acltdtDestroyChannel(*handle) != ACL_SUCCESS) {
destroy_all = false; destroy_all = false;
} else { } else {
handle = nullptr; *handle = nullptr;
} }
} }
} }
return destroy_all; return destroy_all;
} }
std::vector<acltdtChannelHandle *> TdtHandle::GetHandle() { return acl_handle; }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -17,23 +17,21 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_
#include <iostream> #include <iostream>
#include <vector> #include <set>
#include "acl/acl_tdt.h" #include "acl/acl_tdt.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class TdtHandle { class TdtHandle {
public: public:
static void AddHandle(acltdtChannelHandle *handle); static void AddHandle(acltdtChannelHandle **handle);
static bool DestroyHandle(); static bool DestroyHandle();
static std::vector<acltdtChannelHandle *> GetHandle(); static void DelHandle(acltdtChannelHandle **handle);
private: private:
TdtHandle() {} TdtHandle() {}
static std::vector<acltdtChannelHandle *> acl_handle;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -29,17 +29,14 @@ TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) {
if (acl_handle_ == nullptr) { if (acl_handle_ == nullptr) {
MS_LOG(ERROR) << "Failed to create channel for tdt queue."; MS_LOG(ERROR) << "Failed to create channel for tdt queue.";
} }
TdtHandle::AddHandle(acl_handle_); TdtHandle::AddHandle(&acl_handle_);
} }
TdtPlugin::~TdtPlugin() { TdtPlugin::~TdtPlugin() {
std::vector<acltdtChannelHandle *> total_handle = TdtHandle::GetHandle();
if (std::find(total_handle.begin(), total_handle.end(), acl_handle_) != total_handle.end()) {
if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) { if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed to destroy channel for tdt queue."; MS_LOG(ERROR) << "Failed to destroy channel for tdt queue.";
} }
} }
}
Status TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time, Status TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time,
acltdtTensorType tdt_type) { acltdtTensorType tdt_type) {

@ -78,7 +78,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
} }
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF); ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); acltdtChannelHandle *acl_handle = ms_context_ptr->CreateAclTdtChannelHandle();
if (acl_handle == nullptr) { if (acl_handle == nullptr) {
MS_LOG(EXCEPTION) << "Get acltdt handle failed"; MS_LOG(EXCEPTION) << "Get acltdt handle failed";
return false; return false;
@ -92,7 +92,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if (ms_context_ptr == nullptr) { if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
} }
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) { if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
return true; return true;
@ -102,22 +102,8 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0); ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); ms_context_ptr->DestroyAclTdtChannelHandle();
aclError stopStatus = acltdtStopChannel(acl_handle);
if (stopStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed stop acl data channel for host queue ";
} else {
MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
}
MS_LOG(INFO) << "Succeed run cancellation callback of out-feed dequeue op ";
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
aclError destrodStatus = acltdtDestroyChannel(acl_handle);
if (destrodStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed destroy acl channel for out-feed dequeue op ";
} else {
MS_LOG(INFO) << "Succeed destroy acl channel for out-feed dequeue op ";
}
try { try {
if (ms_context_ptr->acl_tdt_print.joinable()) { if (ms_context_ptr->acl_tdt_print.joinable()) {
MS_LOG(INFO) << "join acl tdt host receive process"; MS_LOG(INFO) << "join acl tdt host receive process";

@ -17,6 +17,7 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
std::set<void **> acl_handle_set = std::set<void **>();
// set default log level to WARNING for all sub modules // set default log level to WARNING for all sub modules
int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING}; int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING};
} // namespace mindspore } // namespace mindspore

@ -22,6 +22,7 @@
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <memory> #include <memory>
#include <set>
#include <functional> #include <functional>
#include "utils/overload.h" #include "utils/overload.h"
#include "./securec.h" #include "./securec.h"
@ -41,6 +42,7 @@ static constexpr size_t GetRelPathPos() noexcept {
} }
namespace mindspore { namespace mindspore {
extern std::set<void **> acl_handle_set __attribute__((visibility("default")));
#define FILE_NAME \ #define FILE_NAME \
(sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \ (sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \
: static_cast<const char *>(__FILE__)) : static_cast<const char *>(__FILE__))

@ -109,6 +109,43 @@ bool MsContext::set_backend_policy(const std::string &policy) {
return true; return true;
} }
#ifdef ENABLE_TDTQUE
namespace py = pybind11;
// Creates the acltdt channel named "TF_RECEIVE_" + "_npu_log" on the device
// given by the MS_CTX_DEVICE_ID context parameter.
//
// On success the new handle is stored in acl_handle_ and the address of that
// member is registered through TdtHandle::AddHandle. On failure nothing is
// stored or registered.
//
// Returns the handle produced by acltdtCreateChannel, or nullptr on failure.
acltdtChannelHandle *MsContext::CreateAclTdtChannelHandle() {
  uint32_t dev_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
  std::string prefix = "TF_RECEIVE_";
  std::string suffix = "_npu_log";
  acltdtChannelHandle *created = acltdtCreateChannel(dev_id, (prefix + suffix).c_str());
  if (created == nullptr) {
    // Creation failed: leave acl_handle_ untouched and let the caller report it.
    return created;
  }
  MS_LOG(INFO) << "Success to create acltdt handle.";
  acl_handle_ = created;
  TdtHandle::AddHandle(&acl_handle_);
  return created;
}
void MsContext::DestroyAclTdtChannelHandle() {
if (acl_handle_ == nullptr) {
MS_LOG(INFO) << "The acl handle has been destroyed and the point is nullptr";
return;
}
aclError stopStatus = acltdtStopChannel(acl_handle_);
if (stopStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed stop acl data channel and the stopStatus is " << stopStatus << std::endl;
return;
}
MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
aclError destroydStatus = acltdtDestroyChannel(acl_handle_);
if (destroydStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed destroy acl channel and the destroyStatus is " << destroydStatus << std::endl;
return;
}
TdtHandle::DelHandle(&acl_handle_);
MS_LOG(INFO) << "Succeed destroy acl channel";
}
#endif
std::string MsContext::backend_policy() const { std::string MsContext::backend_policy() const {
auto res = std::find_if( auto res = std::find_if(
policy_map_.begin(), policy_map_.end(), policy_map_.begin(), policy_map_.end(),
@ -127,21 +164,4 @@ bool MsContext::enable_dump_ir() const {
#endif #endif
} }
#ifdef ENABLE_TDTQUE
// Returns the acltdt channel handle cached in acl_handle, creating it on
// first use.
//
// When no handle is cached, a channel named "TF_RECEIVE_" + "_npu_log" is
// created on the device given by the MS_CTX_DEVICE_ID context parameter and
// the result is stored in acl_handle. Returns nullptr (after logging an
// error) if the channel cannot be created.
acltdtChannelHandle *MsContext::get_acl_tdt_channel_handle() {
  // Fast path: a channel was already created earlier.
  if (acl_handle != nullptr) {
    return acl_handle;
  }

  std::string prefix = "TF_RECEIVE_";
  std::string name = "_npu_log";
  uint32_t dev_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
  acl_handle = acltdtCreateChannel(dev_id, (prefix + name).c_str());
  if (acl_handle != nullptr) {
    MS_LOG(INFO) << "Success to create acltdt handle: " << name;
    return acl_handle;
  }
  MS_LOG(ERROR) << "Failed to create acltdt handle : " << name;
  return nullptr;
}
#endif
} // namespace mindspore } // namespace mindspore

@ -25,9 +25,15 @@
#include <utility> #include <utility>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#ifdef ENABLE_TDTQUE
#include "pybind11/pybind11.h"
#include "mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h"
using mindspore::dataset::TdtHandle;
#endif
#ifndef NO_DLIB #ifndef NO_DLIB
#include "acl/acl_tdt.h" #include "acl/acl_tdt.h"
#endif #endif
namespace mindspore { namespace mindspore {
enum MsBackendPolicy { enum MsBackendPolicy {
kMsBackendGeOnly = 0, kMsBackendGeOnly = 0,
@ -137,7 +143,8 @@ class MsContext {
std::string backend_policy() const; std::string backend_policy() const;
bool set_backend_policy(const std::string &policy); bool set_backend_policy(const std::string &policy);
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
acltdtChannelHandle *get_acl_tdt_channel_handle(); acltdtChannelHandle *CreateAclTdtChannelHandle();
void DestroyAclTdtChannelHandle();
#endif #endif
static void device_seter(DeviceSeter device) { seter_ = device; } static void device_seter(DeviceSeter device) { seter_ = device; }
static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; }
@ -175,10 +182,9 @@ class MsContext {
uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS]; uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS];
float float_params_[MsCtxParam::NUM_FLOAT_PARAMS]; float float_params_[MsCtxParam::NUM_FLOAT_PARAMS];
std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
MsBackendPolicy backend_policy_; MsBackendPolicy backend_policy_;
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = nullptr; acltdtChannelHandle *acl_handle_ = nullptr;
#endif #endif
}; };

@ -14,9 +14,6 @@
* limitations under the License. * limitations under the License.
*/ */
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include <string>
#include <vector>
#include <atomic>
namespace mindspore { namespace mindspore {
namespace common { namespace common {

@ -19,6 +19,8 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <string> #include <string>
#include <vector>
#include <atomic>
#define DISABLE_COPY_AND_ASSIGN(ClassType) \ #define DISABLE_COPY_AND_ASSIGN(ClassType) \
ClassType(const ClassType &) = delete; \ ClassType(const ClassType &) = delete; \

Loading…
Cancel
Save