!14133 tensorprint_debug

From: @yepei6
Reviewed-by: 
Signed-off-by:
pull/14133/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0ef2d78411

@ -14,31 +14,37 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/tdt/tdt_handle.h"
namespace mindspore {
extern std::set<void **> acl_handle_set;
namespace dataset {
std::vector<acltdtChannelHandle *> TdtHandle::acl_handle = std::vector<acltdtChannelHandle *>();
void TdtHandle::AddHandle(acltdtChannelHandle *handle) {
if (handle != nullptr) {
acl_handle.emplace_back(handle);
void TdtHandle::AddHandle(acltdtChannelHandle **handle) {
if (*handle != nullptr) {
acl_handle_set.insert(reinterpret_cast<void **>(handle));
}
}
void TdtHandle::DelHandle(acltdtChannelHandle **handle) {
void **void_handle = reinterpret_cast<void **>(handle);
acl_handle_set.erase(void_handle);
}
bool TdtHandle::DestroyHandle() {
bool destroy_all = true;
for (auto &handle : acl_handle) {
if (handle != nullptr) {
if (acltdtDestroyChannel(handle) != ACL_SUCCESS) {
for (auto it = acl_handle_set.begin(); it != acl_handle_set.end(); it++) {
acltdtChannelHandle **handle = reinterpret_cast<acltdtChannelHandle **>(*it);
if (*handle != nullptr) {
acltdtStopChannel(*handle);
if (acltdtDestroyChannel(*handle) != ACL_SUCCESS) {
destroy_all = false;
} else {
handle = nullptr;
*handle = nullptr;
}
}
}
return destroy_all;
}
std::vector<acltdtChannelHandle *> TdtHandle::GetHandle() { return acl_handle; }
} // namespace dataset
} // namespace mindspore

@ -17,23 +17,21 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_
#include <iostream>
#include <vector>
#include <set>
#include "acl/acl_tdt.h"
namespace mindspore {
namespace dataset {
class TdtHandle {
public:
static void AddHandle(acltdtChannelHandle *handle);
static void AddHandle(acltdtChannelHandle **handle);
static bool DestroyHandle();
static std::vector<acltdtChannelHandle *> GetHandle();
static void DelHandle(acltdtChannelHandle **handle);
private:
TdtHandle() {}
static std::vector<acltdtChannelHandle *> acl_handle;
};
} // namespace dataset
} // namespace mindspore

@ -29,16 +29,13 @@ TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) {
if (acl_handle_ == nullptr) {
MS_LOG(ERROR) << "Failed to create channel for tdt queue.";
}
TdtHandle::AddHandle(acl_handle_);
TdtHandle::AddHandle(&acl_handle_);
}
TdtPlugin::~TdtPlugin() {
std::vector<acltdtChannelHandle *> total_handle = TdtHandle::GetHandle();
if (std::find(total_handle.begin(), total_handle.end(), acl_handle_) != total_handle.end()) {
if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed to destroy channel for tdt queue.";
}
}
}
Status TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time,

@ -78,7 +78,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
}
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle();
acltdtChannelHandle *acl_handle = ms_context_ptr->CreateAclTdtChannelHandle();
if (acl_handle == nullptr) {
MS_LOG(EXCEPTION) << "Get acltdt handle failed";
return false;
@ -92,7 +92,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
}
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
return true;
@ -102,22 +102,8 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle();
aclError stopStatus = acltdtStopChannel(acl_handle);
if (stopStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed stop acl data channel for host queue ";
} else {
MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
}
MS_LOG(INFO) << "Succeed run cancellation callback of out-feed dequeue op ";
ms_context_ptr->DestroyAclTdtChannelHandle();
py::gil_scoped_release gil_release;
aclError destrodStatus = acltdtDestroyChannel(acl_handle);
if (destrodStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed destroy acl channel for out-feed dequeue op ";
} else {
MS_LOG(INFO) << "Succeed destroy acl channel for out-feed dequeue op ";
}
try {
if (ms_context_ptr->acl_tdt_print.joinable()) {
MS_LOG(INFO) << "join acl tdt host receive process";

@ -17,6 +17,7 @@
#include "utils/log_adapter.h"
namespace mindspore {
std::set<void **> acl_handle_set = std::set<void **>();
// set default log level to WARNING for all sub modules
int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING};
} // namespace mindspore

@ -22,6 +22,7 @@
#include <string>
#include <sstream>
#include <memory>
#include <set>
#include <functional>
#include "utils/overload.h"
#include "./securec.h"
@ -41,6 +42,7 @@ static constexpr size_t GetRelPathPos() noexcept {
}
namespace mindspore {
extern std::set<void **> acl_handle_set __attribute__((visibility("default")));
#define FILE_NAME \
(sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \
: static_cast<const char *>(__FILE__))

@ -109,6 +109,43 @@ bool MsContext::set_backend_policy(const std::string &policy) {
return true;
}
#ifdef ENABLE_TDTQUE
namespace py = pybind11;
acltdtChannelHandle *MsContext::CreateAclTdtChannelHandle() {
uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
std::string kReceivePrefix = "TF_RECEIVE_";
std::string channel_name = "_npu_log";
acltdtChannelHandle *acl_handle = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
if (acl_handle != nullptr) {
MS_LOG(INFO) << "Success to create acltdt handle.";
acl_handle_ = acl_handle;
TdtHandle::AddHandle(&acl_handle_);
}
return acl_handle;
}
void MsContext::DestroyAclTdtChannelHandle() {
if (acl_handle_ == nullptr) {
MS_LOG(INFO) << "The acl handle has been destroyed and the point is nullptr";
return;
}
aclError stopStatus = acltdtStopChannel(acl_handle_);
if (stopStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed stop acl data channel and the stopStatus is " << stopStatus << std::endl;
return;
}
MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
aclError destroydStatus = acltdtDestroyChannel(acl_handle_);
if (destroydStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed destroy acl channel and the destroyStatus is " << destroydStatus << std::endl;
return;
}
TdtHandle::DelHandle(&acl_handle_);
MS_LOG(INFO) << "Succeed destroy acl channel";
}
#endif
std::string MsContext::backend_policy() const {
auto res = std::find_if(
policy_map_.begin(), policy_map_.end(),
@ -127,21 +164,4 @@ bool MsContext::enable_dump_ir() const {
#endif
}
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *MsContext::get_acl_tdt_channel_handle() {
if (acl_handle == nullptr) {
std::string kReceivePrefix = "TF_RECEIVE_";
std::string channel_name = "_npu_log";
uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
acl_handle = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
if (acl_handle == nullptr) {
MS_LOG(ERROR) << "Failed to create acltdt handle : " << channel_name;
return nullptr;
}
MS_LOG(INFO) << "Success to create acltdt handle: " << channel_name;
return acl_handle;
}
return acl_handle;
}
#endif
} // namespace mindspore

@ -25,9 +25,15 @@
#include <utility>
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#ifdef ENABLE_TDTQUE
#include "pybind11/pybind11.h"
#include "mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h"
using mindspore::dataset::TdtHandle;
#endif
#ifndef NO_DLIB
#include "acl/acl_tdt.h"
#endif
namespace mindspore {
enum MsBackendPolicy {
kMsBackendGeOnly = 0,
@ -137,7 +143,8 @@ class MsContext {
std::string backend_policy() const;
bool set_backend_policy(const std::string &policy);
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *get_acl_tdt_channel_handle();
acltdtChannelHandle *CreateAclTdtChannelHandle();
void DestroyAclTdtChannelHandle();
#endif
static void device_seter(DeviceSeter device) { seter_ = device; }
static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; }
@ -175,10 +182,9 @@ class MsContext {
uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS];
float float_params_[MsCtxParam::NUM_FLOAT_PARAMS];
std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
MsBackendPolicy backend_policy_;
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = nullptr;
acltdtChannelHandle *acl_handle_ = nullptr;
#endif
};

@ -14,9 +14,6 @@
* limitations under the License.
*/
#include "utils/ms_utils.h"
#include <string>
#include <vector>
#include <atomic>
namespace mindspore {
namespace common {

@ -19,6 +19,8 @@
#include <memory>
#include <utility>
#include <string>
#include <vector>
#include <atomic>
#define DISABLE_COPY_AND_ASSIGN(ClassType) \
ClassType(const ClassType &) = delete; \

Loading…
Cancel
Save