回退 'Pull Request !11074 : replace tdt with acltdt interface'

pull/12198/head
gongxiaoqing 4 years ago committed by Gitee
parent 26d4b99019
commit 7f538b51e7

@ -267,8 +267,6 @@ if(ENABLE_D)
find_library(REGISTER register ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(REGISTER register ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(PLATFORM platform ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(PLATFORM platform ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(OPTILING optiling ${ASCEND_OPP_PATH} ${ASCEND_TOOLKIT_OPP_PATH}) find_library(OPTILING optiling ${ASCEND_OPP_PATH} ${ASCEND_TOOLKIT_OPP_PATH})
find_library(ACL ascendcl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
# hccl_adpter # hccl_adpter
find_library(HCCL_ADPTER hcom_graph_adaptor ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(HCCL_ADPTER hcom_graph_adaptor ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(HCCL_RA ra ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(HCCL_RA ra ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
@ -283,7 +281,7 @@ if(ENABLE_D)
mindspore::protobuf -Wl,--end-group) mindspore::protobuf -Wl,--end-group)
target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER} target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER}
${HCCL_ADPTER} ${REGISTER} -Wl,--no-as-needed ${OPTILING} ${HCCL_BUILDER} ${HCCL_ADPTER} ${REGISTER} -Wl,--no-as-needed ${OPTILING} ${HCCL_BUILDER}
${HCCL_RA} ${PLATFORM} ${ACL}) ${HCCL_RA} ${PLATFORM})
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group) target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group)
elseif(CMAKE_SYSTEM_NAME MATCHES "Windows") elseif(CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece

@ -264,7 +264,7 @@ if(ENABLE_GPUQUE)
endif() endif()
if(ENABLE_TDTQUE) if(ENABLE_TDTQUE)
target_link_libraries(_c_dataengine PRIVATE ${ACL}) target_link_libraries(_c_dataengine PRIVATE ${TSDCLIENT})
endif() endif()
add_dependencies(_c_dataengine _c_mindrecord) add_dependencies(_c_dataengine _c_mindrecord)

@ -131,8 +131,8 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// Function to return a transferred Node that transfers data through a device. // Function to return a transferred Node that transfers data through a device.
bool Dataset::DeviceQueue(std::string queue_name, std::string device_type, int32_t device_id, int32_t num_epochs, bool Dataset::DeviceQueue(std::string queue_name, std::string device_type, int32_t num_epochs, bool send_epoch_end,
bool send_epoch_end, int32_t total_batches, bool create_data_info_queue) { int32_t total_batches, bool create_data_info_queue) {
Status rc; Status rc;
// Build and launch tree // Build and launch tree
@ -144,8 +144,8 @@ bool Dataset::DeviceQueue(std::string queue_name, std::string device_type, int32
} }
// Add TransferNode IR on top of dataset // Add TransferNode IR on top of dataset
auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), queue_name, device_type, device_id, auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), queue_name, device_type, send_epoch_end,
send_epoch_end, total_batches, create_data_info_queue); total_batches, create_data_info_queue);
// Get ToDevice consumer // Get ToDevice consumer
auto consumer = std::make_unique<ToDevice>(num_epochs); auto consumer = std::make_unique<ToDevice>(num_epochs);

@ -521,10 +521,9 @@ PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
(void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode", (void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode",
"to create a TransferNode") "to create a TransferNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type, .def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type,
int32_t device_id, bool send_epoch_end, int32_t total_batch, bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) {
bool create_data_info_queue) { auto transfer = std::make_shared<TransferNode>(self, queue_name, device_type, send_epoch_end,
auto transfer = std::make_shared<TransferNode>( total_batch, create_data_info_queue);
self, queue_name, device_type, device_id, send_epoch_end, total_batch, create_data_info_queue);
THROW_IF_ERROR(transfer->ValidateParams()); THROW_IF_ERROR(transfer->ValidateParams());
return transfer; return transfer;
})); }));

@ -55,7 +55,6 @@ DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, i
#endif #endif
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
ascend_keep_waiting_ = true; ascend_keep_waiting_ = true;
tdtInstancePtr = std::make_shared<TdtPlugin>(channel_name_, device_id_);
#endif #endif
} }
@ -153,7 +152,7 @@ Status DeviceQueueOp::SendDataToAscend() {
RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
WaitContinueSignal(); WaitContinueSignal();
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
if (status != Status::OK()) { if (status == TdtStatus::FAILED) {
if (stop_send_) { if (stop_send_) {
MS_LOG(INFO) << "stop_send received"; MS_LOG(INFO) << "stop_send received";
return Status::OK(); return Status::OK();
@ -184,9 +183,9 @@ Status DeviceQueueOp::SendDataToAscend() {
} }
if (current_buffer->eoe() && send_epoch_end_) { if (current_buffer->eoe() && send_epoch_end_) {
TensorRow currRow; TensorRow currRow;
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, auto status =
ACL_TENSOR_DATA_END_OF_SEQUENCE); tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
if (status != Status::OK()) { if (status == TdtStatus::FAILED) {
if (stop_send_) { if (stop_send_) {
MS_LOG(INFO) << "stop_send received"; MS_LOG(INFO) << "stop_send received";
return Status::OK(); return Status::OK();
@ -203,6 +202,7 @@ Status DeviceQueueOp::SendDataToAscend() {
} }
RETURN_IF_NOT_OK(GetNextInput(&current_buffer)); RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
} }
tree_->SetFinished(); tree_->SetFinished();
return Status::OK(); return Status::OK();

@ -32,20 +32,20 @@ namespace dataset {
// Constructor for TransferNode // Constructor for TransferNode
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type,
int32_t device_id, bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
: prefetch_size_(16), : prefetch_size_(16),
queue_name_(std::move(queue_name)), queue_name_(std::move(queue_name)),
device_type_(std::move(device_type)), device_type_(std::move(device_type)),
send_epoch_end_(send_epoch_end), send_epoch_end_(send_epoch_end),
total_batch_(total_batch), total_batch_(total_batch),
create_data_info_queue_(create_data_info_queue), create_data_info_queue_(create_data_info_queue),
device_id_(device_id) { device_id_(0) {
this->AddChild(child); this->AddChild(child);
} }
std::shared_ptr<DatasetNode> TransferNode::Copy() { std::shared_ptr<DatasetNode> TransferNode::Copy() {
auto node = std::make_shared<TransferNode>(nullptr, queue_name_, device_type_, device_id_, send_epoch_end_, auto node = std::make_shared<TransferNode>(nullptr, queue_name_, device_type_, send_epoch_end_, total_batch_,
total_batch_, create_data_info_queue_); create_data_info_queue_);
return node; return node;
} }
@ -96,9 +96,9 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
} }
// // Get device ID (shard ID) from children // Get device ID (shard ID) from children
// device_id_ = 0; device_id_ = 0;
// RETURN_IF_NOT_OK(this->GetShardId(&device_id_)); RETURN_IF_NOT_OK(this->GetShardId(&device_id_));
auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
total_batch_, create_data_info_queue_); total_batch_, create_data_info_queue_);

@ -29,8 +29,8 @@ namespace dataset {
class TransferNode : public DatasetNode { class TransferNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, int32_t device_id, TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, bool send_epoch_end,
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue); int32_t total_batch, bool create_data_info_queue);
/// \brief Destructor /// \brief Destructor
~TransferNode() = default; ~TransferNode() = default;

@ -1,6 +1,5 @@
file( file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
GLOB_RECURSE _CURRENT_SRC_FILES
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-tdt OBJECT tdt_plugin.cc tdt_handle.cc) add_library(engine-tdt OBJECT
tdt_plugin.cc
)

@ -1,39 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/tdt/tdt_handle.h"
namespace mindspore {
namespace dataset {
// Process-wide list of the channel handles recorded through AddHandle.
std::vector<acltdtChannelHandle *> TdtHandle::acl_handle = std::vector<acltdtChannelHandle *>();

// Records a channel handle so that DestroyHandle can later release it.
// A null handle is ignored and nothing is stored for it.
void TdtHandle::AddHandle(acltdtChannelHandle *handle) {
  if (handle == nullptr) {
    return;
  }
  acl_handle.push_back(handle);
}
// Destroys every channel handle that was recorded through AddHandle.
//
// All recorded handles are visited even when one of them fails to be
// destroyed, so a single failing channel no longer leaves the remaining
// channels alive. A handle that was destroyed successfully is reset to
// nullptr in the list, which makes repeated calls skip it instead of passing
// an already-destroyed handle to acltdtDestroyChannel again. Handles whose
// destruction failed stay in the list and are retried on the next call.
//
// Returns true when no acltdtDestroyChannel call reported a failure (this
// includes the case of an empty list), false otherwise.
bool TdtHandle::DestroyHandle() {
  bool all_destroyed = true;
  for (auto &handle : acl_handle) {
    if (handle == nullptr) {
      // Either never valid or already destroyed by an earlier call.
      continue;
    }
    if (acltdtDestroyChannel(handle) != ACL_SUCCESS) {
      all_destroyed = false;
      continue;
    }
    // Forget the handle so it cannot be destroyed a second time.
    handle = nullptr;
  }
  return all_destroyed;
}
} // namespace dataset
} // namespace mindspore

@ -1,38 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_
#include <iostream>
#include <vector>
#include "acl/acl_tdt.h"
namespace mindspore {
namespace dataset {
// Static-only registry of acltdt channel handles. All state lives in the
// static acl_handle list; the constructor is private, so the class is used
// exclusively through its static member functions.
class TdtHandle {
public:
// Stores a non-null channel handle in the registry; a null handle is ignored.
static void AddHandle(acltdtChannelHandle *handle);
// Calls acltdtDestroyChannel on the stored handles. Returns true only when
// no destroy call reported a failure.
static bool DestroyHandle();
private:
// Private so that code outside the class cannot create instances.
TdtHandle() {}
// Handles collected by AddHandle and consumed by DestroyHandle.
static std::vector<acltdtChannelHandle *> acl_handle;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_

@ -23,138 +23,108 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) { static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr;
// create acl tdt handle
acl_handle_ = acltdtCreateChannel(device_id, channel_name.c_str());
if (acl_handle_ == nullptr) {
MS_LOG(ERROR) << "Failed to create channel for tdt queue.";
}
TdtHandle::AddHandle(acl_handle_);
}
TdtPlugin::~TdtPlugin() { std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() {
if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) { if (instance_ptr_ == nullptr) {
MS_LOG(ERROR) << "Failed to destroy channel for tdt queue."; instance_ptr_ = std::shared_ptr<TdtPlugin>(new TdtPlugin);
} }
return instance_ptr_;
} }
Status TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time, TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time,
acltdtTensorType tdt_type) { tdt::TdtDataType tdt_type) {
MS_LOG(DEBUG) << "TDT channel name is " << channel_name << "."; MS_LOG(DEBUG) << "TDT channel name is " << channel_name << ".";
std::vector<DataItem> items;
acltdtDataset *acl_dataset = nullptr;
double start_time; double start_time;
auto ret = translate(tdt_type, ts_row, &acl_dataset); if (tdt_type == tdt::TDT_TENSOR) {
if (ret != Status::OK()) { auto ret = translate(ts_row, items);
DestroyAclDataset(acl_dataset); if (ret != SUCCESS) {
RETURN_STATUS_UNEXPECTED("TDT converting tensor failed!"); MS_LOG(ERROR) << "TDT converting tensor failed!";
return FAILED;
}
} else if (tdt_type == tdt::TDT_END_OF_SEQUENCE) {
DataItem data_item;
data_item.dataType_ = tdt::TDT_END_OF_SEQUENCE;
items.emplace_back(data_item);
MS_LOG(INFO) << "TDT data type is TDT_END_OF_SEQUENCE";
} }
if (profiling) { if (profiling) {
start_time = ProfilingTime::GetCurMilliSecond(); start_time = ProfilingTime::GetCurMilliSecond();
} }
#if ENABLE_D #if ENABLE_D
// Data prefetch only when PS mode enables cache. // Data prefetch only when PS mode enables cache.
if (acltdtGetDatasetSize(acl_dataset) > 0) { if (items.size() > 0) {
acltdtDataItem *item0 = acltdtGetDataItem(acl_dataset, 0); if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) {
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, acltdtGetDataAddrFromItem(item0), return FAILED;
acltdtGetDataSizeFromItem(item0))) {
RETURN_STATUS_UNEXPECTED("PrefetchData failed in when pre-processing sending data.");
} }
} }
#endif #endif
auto status = acltdtSendTensor(acl_handle_, acl_dataset, -1); if (tdt::TdtHostPushData(channel_name, items) != 0) {
DestroyAclDataset(acl_dataset); return FAILED;
if (status != ACL_SUCCESS) {
RETURN_STATUS_UNEXPECTED("Tdt Send data failed.");
} }
if (profiling) { if (profiling) {
double end_time = ProfilingTime::GetCurMilliSecond(); double end_time = ProfilingTime::GetCurMilliSecond();
time = (int32_t)(end_time - start_time); time = (int32_t)(end_time - start_time);
} }
return Status::OK(); return SUCCESS;
} }
Status TdtPlugin::getTdtType(DataType d_type, aclDataType &datatype) { TdtStatus TdtPlugin::getTdtType(DataType d_type, std::string &datatype) {
switch (d_type.value()) { switch (d_type.value()) {
case DataType::DE_BOOL: case DataType::DE_BOOL:
datatype = ACL_BOOL; datatype = "bool";
break; break;
case DataType::DE_INT8: case DataType::DE_INT8:
datatype = ACL_INT8; datatype = "int8";
break; break;
case DataType::DE_UINT8: case DataType::DE_UINT8:
datatype = ACL_UINT8; datatype = "uint8";
break; break;
case DataType::DE_INT16: case DataType::DE_INT16:
datatype = ACL_INT16; datatype = "int16";
break; break;
case DataType::DE_UINT16: case DataType::DE_UINT16:
datatype = ACL_UINT16; datatype = "uint16";
break; break;
case DataType::DE_INT32: case DataType::DE_INT32:
datatype = ACL_INT32; datatype = "int32";
break; break;
case DataType::DE_UINT32: case DataType::DE_UINT32:
datatype = ACL_UINT32; datatype = "uint32";
break; break;
case DataType::DE_FLOAT16: case DataType::DE_FLOAT16:
datatype = ACL_FLOAT16; datatype = "float16";
break; break;
case DataType::DE_FLOAT32: case DataType::DE_FLOAT32:
datatype = ACL_FLOAT; datatype = "float32";
break; break;
case DataType::DE_FLOAT64: case DataType::DE_FLOAT64:
datatype = ACL_DOUBLE; datatype = "float64";
break; break;
case DataType::DE_INT64: case DataType::DE_INT64:
datatype = ACL_INT64; datatype = "int64";
break; break;
case DataType::DE_UINT64: case DataType::DE_UINT64:
datatype = ACL_UINT64; datatype = "uint64";
break; break;
default: default:
RETURN_STATUS_UNEXPECTED("Invalid data, got unexpected data type."); return FAILED;
} }
return Status::OK(); return SUCCESS;
} }
Status TdtPlugin::translate(acltdtTensorType tdt_type, const TensorRow &ts_row, acltdtDataset **output_acl_dataset) { TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &items) {
auto acl_dataset = acltdtCreateDataset(); if (ts_row.size() == 0) {
if (acl_dataset == nullptr) { MS_LOG(ERROR) << "TDT the size of row is zero.";
RETURN_STATUS_UNEXPECTED("Create tdt dataset failed."); return SUCCESS;
}
auto status = AssembleTensor2AclDataset(tdt_type, ts_row, acl_dataset);
if (status != Status::OK()) {
DestroyAclDataset(acl_dataset);
RETURN_STATUS_UNEXPECTED("Assemble tensor row to tdt dataset failed.");
} }
*output_acl_dataset = acl_dataset;
return Status::OK();
}
Status TdtPlugin::AssembleTensor2AclDataset(acltdtTensorType tdt_type, const TensorRow &ts_row,
acltdtDataset *acl_dataset) {
if (tdt_type != ACL_TENSOR_DATA_TENSOR || ts_row.size() == 0) {
acltdtDataItem *acl_data = acltdtCreateDataItem(tdt_type, nullptr, 0, ACL_BOOL, nullptr, 0);
if (acl_data == nullptr) {
RETURN_STATUS_UNEXPECTED("Create data item failed when send data with type:" + std::to_string(tdt_type));
}
if (acltdtAddDataItem(acl_dataset, acl_data) != ACL_SUCCESS) {
if (acltdtDestroyDataItem(acl_data) != ACL_SUCCESS) {
MS_LOG(ERROR) << "Destroy data item failed when send data with type: " << tdt_type;
}
RETURN_STATUS_UNEXPECTED("Add data item to tdt dataset failed when send data.");
}
return Status::OK();
}
for (auto ts : ts_row) { for (auto ts : ts_row) {
aclDataType datatype; std::string datatype;
acltdtDataItem *acl_data = nullptr; TdtStatus status = getTdtType(ts->type(), datatype);
RETURN_IF_NOT_OK(getTdtType(ts->type(), datatype)); if (status != SUCCESS) {
return status;
}
TensorShape tsShape = ts->shape(); TensorShape tsShape = ts->shape();
std::string dataShapes = "["; std::string dataShapes = "[";
for (auto dim : tsShape.AsVector()) { for (auto dim : tsShape.AsVector()) {
@ -162,46 +132,18 @@ Status TdtPlugin::AssembleTensor2AclDataset(acltdtTensorType tdt_type, const Ten
} }
dataShapes.pop_back(); dataShapes.pop_back();
(void)dataShapes.append("]"); (void)dataShapes.append("]");
DataItem data_item;
std::shared_ptr<void> dataPtr = data_item.dataType_ = tdt::TDT_TENSOR;
data_item.tensorShape_ = dataShapes;
data_item.tensorType_ = datatype;
data_item.dataLen_ = ts->SizeInBytes();
data_item.dataPtr_ =
std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {}); std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {});
size_t dataLen = ts->SizeInBytes(); items.emplace_back(data_item);
const dsize_t dims = tsShape.Rank();
std::vector<int64_t> dataShape;
for (auto i = 0; i < dims; i++) {
dataShape.emplace_back(tsShape[i]);
}
acl_data = acltdtCreateDataItem(ACL_TENSOR_DATA_TENSOR, (tsShape.empty() ? nullptr : &dataShape[0]), dims, datatype,
dataPtr.get(), dataLen);
if (acl_data == nullptr) {
RETURN_STATUS_UNEXPECTED("Create data item failed when send data.");
}
if (acltdtAddDataItem(acl_dataset, acl_data) != ACL_SUCCESS) {
if (acltdtDestroyDataItem(acl_data) != ACL_SUCCESS) {
MS_LOG(ERROR) << "Destroy data item failed when send data with type ACL_TENSOR_DATA_TENSOR.";
}
RETURN_STATUS_UNEXPECTED("Add data item to tdt dataset failed when send data.");
}
MS_LOG(DEBUG) << "TDT data type is TDT_TENSOR, tensor type is " << datatype << ", tensor shape is " << dataShapes MS_LOG(DEBUG) << "TDT data type is TDT_TENSOR, tensor type is " << datatype << ", tensor shape is " << dataShapes
<< ", data length is " << ts->Size() << "."; << ", data length is " << ts->Size() << ".";
} }
return SUCCESS;
return Status::OK();
}
Status TdtPlugin::DestroyAclDataset(acltdtDataset *acl_dataset, bool include_data_item) {
if (include_data_item) {
for (size_t i = 0; i < acltdtGetDatasetSize(acl_dataset); i++) {
if (acltdtDestroyDataItem(acltdtGetDataItem(acl_dataset, i)) != ACL_SUCCESS) {
RETURN_STATUS_UNEXPECTED("Destroy data item failed when send data.");
}
}
}
if (acltdtDestroyDataset(acl_dataset) != ACL_SUCCESS) {
RETURN_STATUS_UNEXPECTED("Destroy tdt dataset failed when send data.");
}
return Status::OK();
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -22,40 +22,33 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "acl/acl_tdt.h" #include "tdt/tdt_host_interface.h"
#include "minddata/dataset/engine/tdt/tdt_handle.h"
#include "minddata/dataset/core/data_type.h" #include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h" #include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
enum TdtStatus { SUCCESS, FAILED };
using tdt::DataItem;
class TdtPlugin { class TdtPlugin {
public: public:
static std::shared_ptr<TdtPlugin> GetInstance(); static std::shared_ptr<TdtPlugin> GetInstance();
Status hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time, TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time,
acltdtTensorType tdt_type = ACL_TENSOR_DATA_TENSOR); tdt::TdtDataType tdt_type = tdt::TDT_TENSOR);
TdtPlugin(const std::string &channel_name, int32_t device_id);
~TdtPlugin();
private: private:
Status DestroyAclDataset(acltdtDataset *acl_dataset, bool include_data_item = true); TdtPlugin() {}
Status AssembleTensor2AclDataset(acltdtTensorType tdt_type, const TensorRow &ts_row, acltdtDataset *acl_dataset); TdtStatus getTdtType(DataType d_type, std::string &datatype);
Status getTdtType(DataType d_type, aclDataType &datatype); TdtStatus translate(const TensorRow &ts_row, std::vector<DataItem> &items);
Status translate(acltdtTensorType tdt_type, const TensorRow &ts_row, acltdtDataset **output_acl_dataset);
void *tdt_handle_ = nullptr; void *tdt_handle_ = nullptr;
acltdtChannelHandle *acl_handle_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -152,16 +152,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// of data transmission per time is 256M. /// of data transmission per time is 256M.
/// \param[in] queue_name Channel name (default="", create new unique name). /// \param[in] queue_name Channel name (default="", create new unique name).
/// \param[in] device_type Type of device (default="", get from MSContext). /// \param[in] device_type Type of device (default="", get from MSContext).
/// \param[in] device_id id of device (default=0, get from MSContext).
/// \param[in] num_epochs Number of epochs (default=-1, infinite epochs). /// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true). /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
/// \param[in] total_batches Number of batches to be sent to the device (default=0, all data). /// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes /// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
/// of data or not(default=false). /// of data or not(default=false).
/// \return Returns true if no error encountered else false. /// \return Returns true if no error encountered else false.
bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0, bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t num_epochs = -1,
int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0, bool send_epoch_end = true, int32_t total_batches = 0, bool create_data_info_queue = false);
bool create_data_info_queue = false);
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
/// \note Usage restrictions: /// \note Usage restrictions:

@ -23,9 +23,8 @@
#include "minddata/dataset/util/services.h" #include "minddata/dataset/util/services.h"
#endif #endif
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
#include "acl/acl_tdt.h" #include "tdt/tdt_host_interface.h"
#include "tdt/status.h" #include "tdt/status.h"
#include "minddata/dataset/engine/tdt/tdt_handle.h"
#endif #endif
namespace mindspore { namespace mindspore {
@ -164,10 +163,11 @@ Status Task::Join(WaitFlag blocking) {
if (wait_times > 5 && my_name_.find("DeviceQueueOp") != std::string::npos) { if (wait_times > 5 && my_name_.find("DeviceQueueOp") != std::string::npos) {
MS_LOG(WARNING) << "Wait " << wait_times << " seconds, " MS_LOG(WARNING) << "Wait " << wait_times << " seconds, "
<< "the task: " << my_name_ << " will be destroyed by TdtHostDestory."; << "the task: " << my_name_ << " will be destroyed by TdtHostDestory.";
if (!TdtHandle::DestroyHandle()) { int32_t destory_status = tdt::TdtHostDestroy();
MS_LOG(WARNING) << "Destroy tdt channel failed."; if (destory_status != TDT_OK_CODE) {
MS_LOG(WARNING) << "Destroy tsd failed, status = " << destory_status << ".";
} else { } else {
MS_LOG(INFO) << "Destroy tdt channel success."; MS_LOG(INFO) << "Destroy tsd success.";
} }
// just wait 30 seconds // just wait 30 seconds

@ -1,6 +1,5 @@
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc" file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc" "kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
) )
if(ENABLE_GPU) if(ENABLE_GPU)
@ -10,8 +9,7 @@ else()
endif() endif()
if(ENABLE_D) if(ENABLE_D)
file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc" file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc")
"../../minddata/dataset/engine/tdt/tdt_handle.cc")
endif() endif()
if(ENABLE_CPU) if(ENABLE_CPU)

@ -54,8 +54,8 @@
#include "runtime/device/ascend/profiling/profiling_callback_register.h" #include "runtime/device/ascend/profiling/profiling_callback_register.h"
#include "backend/kernel_compiler/hccl/hccl_context.h" #include "backend/kernel_compiler/hccl/hccl_context.h"
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
#include "minddata/dataset/engine/tdt/tdt_handle.h" #include "tdt/tdt_host_interface.h"
using mindspore::dataset::TdtHandle; #include "tdt/status.h"
#endif #endif
using ge::model_runner::ModelRunner; using ge::model_runner::ModelRunner;
@ -698,10 +698,11 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
// Run task error, we should call TdtHostDestroy to release tdt to avoid DeviceQueueOp hostPush hung // Run task error, we should call TdtHostDestroy to release tdt to avoid DeviceQueueOp hostPush hung
// case1: cpu usage 100% cause thread/process exit, but some tdt thread remain in backend // case1: cpu usage 100% cause thread/process exit, but some tdt thread remain in backend
if (!TdtHandle::DestroyHandle()) { int32_t destory_status = tdt::TdtHostDestroy();
MS_LOG(WARNING) << "Destroy tdt channel failed."; if (destory_status != TDT_OK_CODE) {
MS_LOG(WARNING) << "Destroy tsd failed, status = " << destory_status << ".";
} else { } else {
MS_LOG(INFO) << "Destroy tdt channel success."; MS_LOG(INFO) << "Destroy tsd success.";
} }
#endif #endif
return false; return false;

@ -22,6 +22,7 @@
#include <atomic> #include <atomic>
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "utils/convert_utils_base.h" #include "utils/convert_utils_base.h"
@ -45,7 +46,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
} }
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) { if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened."; MS_LOG(DEBUG) << "TDT Dataset client is already opened.";
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF); ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
return true; return true;
} }
@ -55,8 +56,10 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
return true; return true;
} }
uint32_t rank_size = 1; unsigned int device_id;
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); unsigned int rank_size = 1;
device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto rank_size_env = common::GetEnv("RANK_SIZE"); auto rank_size_env = common::GetEnv("RANK_SIZE");
if (rank_size_env.empty()) { if (rank_size_env.empty()) {
@ -78,14 +81,14 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
} }
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF); ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); int32_t initStatus = tdt::TdtHostInit(device_id);
if (acl_handle == nullptr) { if (initStatus != TDT_OK_CODE) {
MS_LOG(EXCEPTION) << "Get acltdt handle failed"; MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
return false; return false;
} }
ms_context_ptr->acl_tdt_print = std::thread(TensorPrint(acl_handle)); ms_context_ptr->tdt_print_ = std::thread(TensorPrint());
#endif #endif
MS_LOG(INFO) << "Get the acltdt handle successful, tsd reference = " MS_LOG(INFO) << "Open and init tsd successful, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << "."; << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
return true; return true;
} }
@ -100,34 +103,28 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF); ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) { if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0); ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); int32_t stopStatus = tdt::TdtHostStop(KNpuLog);
aclError stopStatus = acltdtStopChannel(acl_handle); if (stopStatus != TDT_OK_CODE) {
if (stopStatus != ACL_SUCCESS) { MS_LOG(EXCEPTION) << "Stop tsd failed, status = " << stopStatus << ".";
MS_LOG(ERROR) << "Failed stop acl data channel for host queue "; return false;
} else {
MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
} }
MS_LOG(INFO) << "Succeed run cancellation callback of out-feed dequeue op ";
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
aclError destrodStatus = acltdtDestroyChannel(acl_handle); int32_t destroyStatus = tdt::TdtHostDestroy();
if (destrodStatus != ACL_SUCCESS) { if (destroyStatus != TDT_OK_CODE) {
MS_LOG(ERROR) << "Failed destroy acl channel for out-feed dequeue op "; MS_LOG(EXCEPTION) << "Destroy tsd failed, status = " << destroyStatus << ".";
} else { return false;
MS_LOG(INFO) << "Succeed destroy acl channel for out-feed dequeue op ";
} }
try { try {
if (ms_context_ptr->acl_tdt_print.joinable()) { if (ms_context_ptr->tdt_print_.joinable()) {
MS_LOG(INFO) << "join acl tdt host receive process"; MS_LOG(INFO) << "join tdt host receive process";
ms_context_ptr->acl_tdt_print.join(); ms_context_ptr->tdt_print_.join();
} }
} catch (const std::exception &e) { } catch (const std::exception &e) {
MS_LOG(ERROR) << "tdt thread join failed: " << e.what(); MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
} }
#endif #endif
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); auto device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto ret = rtDeviceReset(device_id); auto ret = rtDeviceReset(device_id);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]"; MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
@ -136,9 +133,10 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false); ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]"; MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
} else { } else {
MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = " MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << "."; << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
} }
return true; return true;
} }
#else #else
@ -232,7 +230,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
} else { } else {
(*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
} }
// Disable the global variable acc, only enable it while adding training graph in pipeline // Disable the global variable acc, only enable it whlie adding training graph in pipeline
(*ge_options)["ge.exec.variable_acc"] = "0"; (*ge_options)["ge.exec.variable_acc"] = "0";
#endif #endif
} }
@ -310,7 +308,6 @@ bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) { ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
return true; return true;
} }
(void)OpenTsd(ms_context_ptr); (void)OpenTsd(ms_context_ptr);
(void)InitGe(ms_context_ptr); (void)InitGe(ms_context_ptr);
ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true); ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);

@ -24,8 +24,8 @@
#include "utils/tensorprint_utils.h" #include "utils/tensorprint_utils.h"
#ifndef NO_DLIB #ifndef NO_DLIB
#include "acl/acl_tdt.h"
#include "tdt/tsd_client.h" #include "tdt/tsd_client.h"
#include "tdt/tdt_host_interface.h"
#include "tdt/data_common.h" #include "tdt/data_common.h"
#include "runtime/dev.h" #include "runtime/dev.h"
#endif #endif
@ -35,8 +35,8 @@
namespace mindspore { namespace mindspore {
namespace context { namespace context {
bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr); bool OpenTsd(const std::shared_ptr<MsContext> &inst_context);
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force = false); bool CloseTsd(const std::shared_ptr<MsContext> &inst_context, bool force = false);
void SetHcclOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options); void SetHcclOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
void GetGeOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options); void GetGeOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options); void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options);

File diff suppressed because it is too large Load Diff

@ -20,10 +20,9 @@
#include <map> #include <map>
#include "ir/dtype/type.h" #include "ir/dtype/type.h"
#ifndef NO_DLIB #ifndef NO_DLIB
#include "acl/acl_tdt.h"
#include "tdt/tsd_client.h" #include "tdt/tsd_client.h"
#include "tdt/data_common.h"
#include "tdt/tdt_host_interface.h" #include "tdt/tdt_host_interface.h"
#include "tdt/data_common.h"
#include "proto/print.pb.h" #include "proto/print.pb.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#endif #endif
@ -33,11 +32,7 @@ class TensorPrint {
TensorPrint() {} TensorPrint() {}
~TensorPrint() = default; ~TensorPrint() = default;
#ifndef NO_DLIB #ifndef NO_DLIB
explicit TensorPrint(acltdtChannelHandle *acl_handle) { acl_handle_ = acl_handle; }
void operator()(); void operator()();
private:
acltdtChannelHandle *acl_handle_ = nullptr;
#endif #endif
}; };
} // namespace mindspore } // namespace mindspore

@ -50,7 +50,6 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
} else { } else {
set_param<uint32_t>(MS_CTX_DEVICE_ID, 0); set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
} }
set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT); set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
set_param<std::string>(MS_CTX_DEVICE_TARGET, target); set_param<std::string>(MS_CTX_DEVICE_TARGET, target);
set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode); set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
@ -108,22 +107,4 @@ std::string MsContext::backend_policy() const {
} }
return "unknown"; return "unknown";
} }
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *MsContext::get_acl_tdt_channel_handle() {
if (acl_handle == nullptr) {
std::string kReceivePrefix = "TF_RECEIVE_";
std::string channel_name = "_npu_log";
uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
acl_handle = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
if (acl_handle == nullptr) {
MS_LOG(ERROR) << "Failed to create acltdt handle : " << channel_name;
return nullptr;
}
MS_LOG(INFO) << "Success to create acltdt handle: " << channel_name;
return acl_handle;
}
return acl_handle;
}
#endif
} // namespace mindspore } // namespace mindspore

@ -24,10 +24,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#ifndef NO_DLIB
#include "acl/acl_tdt.h"
#endif
namespace mindspore { namespace mindspore {
enum MsBackendPolicy { enum MsBackendPolicy {
kMsBackendGeOnly = 0, kMsBackendGeOnly = 0,
@ -132,13 +129,11 @@ class MsContext {
std::string backend_policy() const; std::string backend_policy() const;
bool set_backend_policy(const std::string &policy); bool set_backend_policy(const std::string &policy);
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *get_acl_tdt_channel_handle();
#endif
static void device_seter(DeviceSeter device) { seter_ = device; } static void device_seter(DeviceSeter device) { seter_ = device; }
static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; }
std::thread acl_tdt_print; std::thread tdt_print_;
template <typename T> template <typename T>
void set_param(MsCtxParam param, const T &value) { void set_param(MsCtxParam param, const T &value) {
@ -173,9 +168,6 @@ class MsContext {
std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
MsBackendPolicy backend_policy_; MsBackendPolicy backend_policy_;
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = nullptr;
#endif
}; };
// set method implementation for type bool/int/uint32_t/float/std::string // set method implementation for type bool/int/uint32_t/float/std::string

@ -2698,11 +2698,10 @@ class TransferDataset(Dataset):
def parse(self, children=None): def parse(self, children=None):
total_batch = 0 total_batch = 0
device_id = context.get_context("device_id")
if hasattr(self.children[0], "__total_batch__"): if hasattr(self.children[0], "__total_batch__"):
total_batch = self.children[0].__total_batch__ total_batch = self.children[0].__total_batch__
return cde.TransferNode(children[0], self.queue_name, self.device_type, device_id, self._send_epoch_end, return cde.TransferNode(children[0], self.queue_name, self.device_type, self._send_epoch_end, total_batch,
total_batch, self._create_data_info_queue) self._create_data_info_queue)
def create_dict_iterator(self, num_epochs=-1, output_numpy=False): def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
raise RuntimeError("TransferDataset is not iterable.") raise RuntimeError("TransferDataset is not iterable.")

@ -54,20 +54,15 @@ def get_tensor(is_scalar, input_type):
if __name__ == "__main__": if __name__ == "__main__":
net = TensorPrint() net = TensorPrint()
# net(get_tensor('scalar', mindspore.bool_), get_tensor('scalar', mindspore.uint8), net(get_tensor('scalar', mindspore.bool_), get_tensor('scalar', mindspore.uint8),
# get_tensor('scalar', mindspore.int8), get_tensor('scalar', mindspore.uint16), get_tensor('scalar', mindspore.int8), get_tensor('scalar', mindspore.uint16),
# get_tensor('scalar', mindspore.int16), get_tensor('scalar', mindspore.uint32), get_tensor('scalar', mindspore.int16), get_tensor('scalar', mindspore.uint32),
# get_tensor('scalar', mindspore.int32), get_tensor('scalar', mindspore.uint64), get_tensor('scalar', mindspore.int32), get_tensor('scalar', mindspore.uint64),
# get_tensor('scalar', mindspore.int64), get_tensor('scalar', mindspore.float16), get_tensor('scalar', mindspore.int64), get_tensor('scalar', mindspore.float16),
# get_tensor('scalar', mindspore.float32), get_tensor('scalar', mindspore.float64),
# get_tensor('array', mindspore.bool_), get_tensor('array', mindspore.uint8),
# get_tensor('array', mindspore.int8), get_tensor('array', mindspore.uint16),
# get_tensor('array', mindspore.int16), get_tensor('array', mindspore.uint32),
# get_tensor('array', mindspore.int32), get_tensor('array', mindspore.uint64),
# get_tensor('array', mindspore.int64), get_tensor('array', mindspore.float16),
# get_tensor('array', mindspore.float32), get_tensor('array', mindspore.float64))
net(get_tensor('scalar', mindspore.bool_),
get_tensor('scalar', mindspore.float32), get_tensor('scalar', mindspore.float64), get_tensor('scalar', mindspore.float32), get_tensor('scalar', mindspore.float64),
get_tensor('array', mindspore.bool_), get_tensor('array', mindspore.bool_), get_tensor('array', mindspore.uint8),
get_tensor('array', mindspore.int8), get_tensor('array', mindspore.uint16),
get_tensor('array', mindspore.int16), get_tensor('array', mindspore.uint32),
get_tensor('array', mindspore.int32), get_tensor('array', mindspore.uint64),
get_tensor('array', mindspore.int64), get_tensor('array', mindspore.float16),
get_tensor('array', mindspore.float32), get_tensor('array', mindspore.float64)) get_tensor('array', mindspore.float32), get_tensor('array', mindspore.float64))

Loading…
Cancel
Save