|
|
|
@ -53,6 +53,7 @@ bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, c
|
|
|
|
|
const std::string supported_data_type = "int32";
|
|
|
|
|
if (data_type != supported_data_type) {
|
|
|
|
|
MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
|
|
|
|
|
invalid_data_type_ = true;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (data == nullptr) {
|
|
|
|
@ -105,12 +106,20 @@ bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void *PsDataPrefetch::data(const std::string &channel_name) const {
|
|
|
|
|
bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const {
|
|
|
|
|
if (invalid_data_type_) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (data_ptr == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto channel = ps_data_channel(channel_name);
|
|
|
|
|
if (channel == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
*data_ptr = nullptr;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return channel->data();
|
|
|
|
|
*data_ptr = channel->data();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
|
|
|
|
|