|
|
|
@ -407,6 +407,9 @@ thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
|
|
|
|
|
// - For fixed-shape, it's a null string in default.
|
|
|
|
|
// - For dynamic-shape, it's user specific.
|
|
|
|
|
thread_local std::string cur_input_shape_str = "";
|
|
|
|
|
// the cache capacity of different input shapes for MKLDNN.
|
|
|
|
|
// Default 1 means fixed input shape, not dynamic shape.
|
|
|
|
|
thread_local int cur_input_shape_cache_capacity = 1;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
|
|
|
|
@ -414,37 +417,58 @@ size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
|
|
|
|
|
void set_cur_input_shape_str(std::string input_shape_str) {
|
|
|
|
|
cur_input_shape_str = input_shape_str;
|
|
|
|
|
}
|
|
|
|
|
std::string get_cur_input_shape_str(void) { return cur_input_shape_str; }
|
|
|
|
|
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
|
|
|
|
|
cur_input_shape_cache_capacity = input_shape_cache_capacity;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
|
|
|
|
|
|
|
|
|
|
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
|
auto map_it = pMap->find(cur_mkldnn_session_id);
|
|
|
|
|
if (map_it == pMap->end()) {
|
|
|
|
|
LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
|
|
|
|
|
<< cur_mkldnn_session_id;
|
|
|
|
|
}
|
|
|
|
|
return map_it->second->size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNDeviceContext::SetBlob(const std::string& name,
|
|
|
|
|
std::shared_ptr<void> data) const {
|
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
|
std::shared_ptr<ShapeBlob> sBlob = nullptr;
|
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
|
|
|
|
|
|
int tid = platform::get_cur_mkldnn_session_id();
|
|
|
|
|
int sid = platform::get_cur_mkldnn_session_id();
|
|
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
|
|
|
|
|
|
// Find ShapeBlob for current thread
|
|
|
|
|
auto map_it = pMap->find(tid);
|
|
|
|
|
// Find ShapeBlob for current mkldnn session id.
|
|
|
|
|
auto map_it = pMap->find(sid);
|
|
|
|
|
|
|
|
|
|
if (map_it == pMap->end()) {
|
|
|
|
|
// 1st time to set blob in current thread
|
|
|
|
|
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
|
|
|
|
|
(*pMap)[tid] = sBlob;
|
|
|
|
|
VLOG(2) << "SetBlob: tid=" << tid << ", add new tid\n";
|
|
|
|
|
(*pMap)[sid] = sBlob;
|
|
|
|
|
VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
|
|
|
|
|
} else {
|
|
|
|
|
sBlob = map_it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Find KeyBlob for current input shape
|
|
|
|
|
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
|
|
|
|
|
auto key_it = sBlob->find(cur_input_shape_str);
|
|
|
|
|
|
|
|
|
|
if (key_it == sBlob->end()) {
|
|
|
|
|
// In cache clearing mode, cur_input_shape_cache_capacity defines
|
|
|
|
|
// max pblob capacity
|
|
|
|
|
if ((sid == kMKLDNNSessionID_CacheClearing) &&
|
|
|
|
|
(sBlob->size() >=
|
|
|
|
|
static_cast<size_t>(cur_input_shape_cache_capacity))) {
|
|
|
|
|
VLOG(2) << "sid=" << sid
|
|
|
|
|
<< ", remove all blobs of shape: " << sBlob->begin()->first;
|
|
|
|
|
sBlob->erase(sBlob->begin()->first);
|
|
|
|
|
}
|
|
|
|
|
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
|
|
|
|
|
(*sBlob)[cur_input_shape_str] = pBlob;
|
|
|
|
|
} else {
|
|
|
|
@ -458,7 +482,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
|
|
|
|
|
} else {
|
|
|
|
|
blob_it->second = data; // set data to existing blob
|
|
|
|
|
}
|
|
|
|
|
VLOG(2) << "SetBlob: tid=" << tid << ", add blob=" << name << "\n";
|
|
|
|
|
VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
|
|
|
|
|
// lock will be automatically released when out of scope
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -469,23 +493,22 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
|
|
|
|
|
std::shared_ptr<ShapeBlob> sBlob = nullptr;
|
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
|
|
|
|
|
|
int tid = platform::get_cur_mkldnn_session_id();
|
|
|
|
|
int sid = platform::get_cur_mkldnn_session_id();
|
|
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
|
|
|
|
|
|
// Find ShapeBlob for current thread firstly
|
|
|
|
|
auto map_it = pMap->find(tid);
|
|
|
|
|
// Find ShapeBlob for current mkldnn session id firstly
|
|
|
|
|
auto map_it = pMap->find(sid);
|
|
|
|
|
if (map_it == pMap->end()) {
|
|
|
|
|
VLOG(2) << "GetBlob: tid=" << tid << ", miss tid\n";
|
|
|
|
|
VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
|
|
|
|
|
sBlob = map_it->second;
|
|
|
|
|
|
|
|
|
|
// Find KeyBlob for current input shape secondly
|
|
|
|
|
auto sBlob_it = sBlob->find(cur_input_shape_str);
|
|
|
|
|
if (sBlob_it == sBlob->end()) {
|
|
|
|
|
VLOG(2) << "GetBlob: tid=" << cur_input_shape_str
|
|
|
|
|
VLOG(2) << "GetBlob: sid=" << cur_input_shape_str
|
|
|
|
|
<< ", miss input_shape_str\n";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
@ -495,11 +518,11 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
|
|
|
|
|
auto key_it = pBlob->find(name);
|
|
|
|
|
|
|
|
|
|
if (key_it == pBlob->end()) {
|
|
|
|
|
VLOG(2) << "GetBlob tid=" << tid << ", miss blob=" << name << "\n";
|
|
|
|
|
VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(2) << "GetBlob tid=" << tid << ", get blob=" << name << "\n";
|
|
|
|
|
VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
|
|
|
|
|
// lock will be automatically released when out of scope
|
|
|
|
|
return key_it->second;
|
|
|
|
|
}
|
|
|
|
|