add input data type check for ps cache mode

pull/12165/head
lizhenyu 4 years ago
parent fd18382853
commit 6649153c49

@ -301,10 +301,9 @@ Status DeviceQueueOp::PushDataToGPU() {
} }
// Data prefetch only when PS mode enables cache. // Data prefetch only when PS mode enables cache.
if (items.size() > 0) { if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_,
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) { items[0].data_type_)) {
return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data."); return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
}
} }
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
@ -434,6 +433,11 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items,
if (sub_item.data_ptr_ == nullptr) { if (sub_item.data_ptr_ == nullptr) {
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed."); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
} }
if (curr_row[i] == nullptr) {
MS_LOG(ERROR) << "The pointer curr_row[" << i << "] is null";
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "TensorRow 'curr_row' contains nullptr.");
}
sub_item.data_type_ = curr_row[i]->type().ToString();
const unsigned char *column_data = curr_row[i]->GetBuffer(); const unsigned char *column_data = curr_row[i]->GetBuffer();
if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data, if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data,
static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) { static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) {

@ -55,7 +55,8 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
#if ENABLE_D #if ENABLE_D
// Data prefetch only when PS mode enables cache. // Data prefetch only when PS mode enables cache.
if (items.size() > 0) { if (items.size() > 0) {
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) { if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_,
items[0].tensorType_)) {
return FAILED; return FAILED;
} }
} }

@ -44,10 +44,17 @@ std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string
return iter->second; return iter->second;
} }
bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
const std::string &data_type) {
if (cache_enable_ == false) { if (cache_enable_ == false) {
return true; return true;
} }
// In ps cache mode, input ids are from dataset and data type transmitted from minddata must be 'int32'
const std::string supported_data_type = "int32";
if (data_type != supported_data_type) {
MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
return false;
}
if (data == nullptr) { if (data == nullptr) {
MS_LOG(WARNING) << "No data prefetch."; MS_LOG(WARNING) << "No data prefetch.";
return true; return true;

@ -37,7 +37,8 @@ class EXPORT PsDataPrefetch {
EXPORT bool cache_enable() const { return cache_enable_; } EXPORT bool cache_enable() const { return cache_enable_; }
EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num);
EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size); EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
const std::string &data_type);
EXPORT bool FinalizeData(const std::string &channel_name); EXPORT bool FinalizeData(const std::string &channel_name);
EXPORT void NotifyFinalize(); EXPORT void NotifyFinalize();
EXPORT void *data(const std::string &channel_name) const; EXPORT void *data(const std::string &channel_name) const;

@ -34,6 +34,7 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST,
struct DataItemGpu { struct DataItemGpu {
int32_t worker_id_; int32_t worker_id_;
std::string data_type_;
size_t data_len_; size_t data_len_;
void *data_ptr_; void *data_ptr_;
}; };

Loading…
Cancel
Save