|
|
@ -403,42 +403,62 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
|
|
|
|
namespace {
|
|
|
|
namespace {
|
|
|
|
// Current mkldnn session id.
|
|
|
|
// Current mkldnn session id.
|
|
|
|
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
|
|
|
|
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
|
|
|
|
}
|
|
|
|
// Current data input shape string.
|
|
|
|
|
|
|
|
// - For fixed-shape, it's a null string in default.
|
|
|
|
|
|
|
|
// - For dynamic-shape, it's user specific.
|
|
|
|
|
|
|
|
thread_local std::string cur_input_shape_str = "";
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
|
|
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
|
|
|
|
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
|
|
|
|
size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
|
|
|
|
size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
|
|
|
|
|
|
|
|
void set_cur_input_shape_str(std::string input_shape_str) {
|
|
|
|
|
|
|
|
cur_input_shape_str = input_shape_str;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string get_cur_input_shape_str(void) { return cur_input_shape_str; }
|
|
|
|
|
|
|
|
|
|
|
|
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
|
|
|
|
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
|
|
|
|
|
|
|
|
|
|
|
|
void MKLDNNDeviceContext::SetBlob(const std::string& name,
|
|
|
|
void MKLDNNDeviceContext::SetBlob(const std::string& name,
|
|
|
|
std::shared_ptr<void> data) const {
|
|
|
|
std::shared_ptr<void> data) const {
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
|
|
|
|
std::shared_ptr<ShapeBlob> sBlob = nullptr;
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
int tid = platform::get_cur_mkldnn_session_id();
|
|
|
|
int tid = platform::get_cur_mkldnn_session_id();
|
|
|
|
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
|
|
|
|
|
|
|
|
// Find KeyBlob for current thread
|
|
|
|
// Find ShapeBlob for current thread
|
|
|
|
auto map_it = pMap->find(tid);
|
|
|
|
auto map_it = pMap->find(tid);
|
|
|
|
|
|
|
|
|
|
|
|
if (map_it == pMap->end()) {
|
|
|
|
if (map_it == pMap->end()) {
|
|
|
|
// 1st time to set blob in current thread
|
|
|
|
// 1st time to set blob in current thread
|
|
|
|
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
|
|
|
|
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
|
|
|
|
(*pMap)[tid] = pBlob;
|
|
|
|
(*pMap)[tid] = sBlob;
|
|
|
|
|
|
|
|
VLOG(2) << "SetBlob: tid=" << tid << ", add new tid\n";
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
pBlob = map_it->second;
|
|
|
|
sBlob = map_it->second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Find Key in found (or newly created) KeyBlob
|
|
|
|
// Find KeyBlob for current input shape
|
|
|
|
auto key_it = pBlob->find(name);
|
|
|
|
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
|
|
|
|
|
|
|
|
auto key_it = sBlob->find(cur_input_shape_str);
|
|
|
|
|
|
|
|
|
|
|
|
if (key_it == pBlob->end()) {
|
|
|
|
if (key_it == sBlob->end()) {
|
|
|
|
(*pBlob)[name] = data; // create new blob
|
|
|
|
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
|
|
|
|
|
|
|
|
(*sBlob)[cur_input_shape_str] = pBlob;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
key_it->second = data; // set data to existing blob
|
|
|
|
pBlob = key_it->second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Find Blob via name
|
|
|
|
|
|
|
|
auto blob_it = pBlob->find(name);
|
|
|
|
|
|
|
|
if (blob_it == pBlob->end()) {
|
|
|
|
|
|
|
|
(*pBlob)[name] = data;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
blob_it->second = data; // set data to existing blob
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
VLOG(2) << "SetBlob: tid=" << tid << ", add blob=" << name << "\n";
|
|
|
|
// lock will be automatically released when out of scope
|
|
|
|
// lock will be automatically released when out of scope
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -446,22 +466,40 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
|
|
|
|
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
|
|
|
|
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
|
|
|
|
const std::string& name) const {
|
|
|
|
const std::string& name) const {
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
|
|
|
|
std::shared_ptr<ShapeBlob> sBlob = nullptr;
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
int tid = platform::get_cur_mkldnn_session_id();
|
|
|
|
int tid = platform::get_cur_mkldnn_session_id();
|
|
|
|
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
|
|
|
|
|
|
|
|
// Find KeyBlob for current thread firstly
|
|
|
|
// Find ShapeBlob for current thread firstly
|
|
|
|
auto map_it = pMap->find(tid);
|
|
|
|
auto map_it = pMap->find(tid);
|
|
|
|
if (map_it == pMap->end()) return nullptr;
|
|
|
|
if (map_it == pMap->end()) {
|
|
|
|
pBlob = map_it->second;
|
|
|
|
VLOG(2) << "GetBlob: tid=" << tid << ", miss tid\n";
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
|
|
|
|
|
|
|
|
sBlob = map_it->second;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Find KeyBlob for current input shape secondly
|
|
|
|
|
|
|
|
auto sBlob_it = sBlob->find(cur_input_shape_str);
|
|
|
|
|
|
|
|
if (sBlob_it == sBlob->end()) {
|
|
|
|
|
|
|
|
VLOG(2) << "GetBlob: tid=" << cur_input_shape_str
|
|
|
|
|
|
|
|
<< ", miss input_shape_str\n";
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
pBlob = sBlob_it->second;
|
|
|
|
|
|
|
|
|
|
|
|
// Find Blob via name
|
|
|
|
// Find Blob via name
|
|
|
|
auto key_it = pBlob->find(name);
|
|
|
|
auto key_it = pBlob->find(name);
|
|
|
|
|
|
|
|
|
|
|
|
if (key_it == pBlob->end()) return nullptr;
|
|
|
|
if (key_it == pBlob->end()) {
|
|
|
|
|
|
|
|
VLOG(2) << "GetBlob tid=" << tid << ", miss blob=" << name << "\n";
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(2) << "GetBlob tid=" << tid << ", get blob=" << name << "\n";
|
|
|
|
// lock will be automatically released when out of scope
|
|
|
|
// lock will be automatically released when out of scope
|
|
|
|
return key_it->second;
|
|
|
|
return key_it->second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|