|
|
|
@ -296,38 +296,73 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
|
|
|
|
|
: CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() {
|
|
|
|
|
p_blobs_.reset(new std::unordered_map<std::string, std::shared_ptr<void>>());
|
|
|
|
|
: CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() {
|
|
|
|
|
p_blobmap_.reset(new BlobMap());
|
|
|
|
|
p_mutex_.reset(new std::mutex());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
// Current thread's id.
|
|
|
|
|
thread_local int cur_thread_id = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void set_cur_thread_id(int tid) { cur_thread_id = tid; }
|
|
|
|
|
int get_cur_thread_id(void) { return cur_thread_id; }
|
|
|
|
|
|
|
|
|
|
void MKLDNNDeviceContext::SetBlob(const std::string& name,
|
|
|
|
|
std::shared_ptr<void> data) const {
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<void>>* p;
|
|
|
|
|
p = p_blobs_.get();
|
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
|
|
|
|
|
|
int tid = platform::get_cur_thread_id();
|
|
|
|
|
|
|
|
|
|
auto it = p->find(name);
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_.get());
|
|
|
|
|
|
|
|
|
|
if (it == p->end()) {
|
|
|
|
|
(*p)[name] = data; // create new blob
|
|
|
|
|
// Find KeyBlob for current thread
|
|
|
|
|
auto map_it = pMap->find(tid);
|
|
|
|
|
|
|
|
|
|
if (map_it == pMap->end()) {
|
|
|
|
|
// 1st time to set blob in current thread
|
|
|
|
|
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
|
|
|
|
|
(*pMap)[tid] = pBlob;
|
|
|
|
|
} else {
|
|
|
|
|
it->second = data; // set data to existing blob
|
|
|
|
|
pBlob = map_it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Find Key in found (or newly created) KeyBlob
|
|
|
|
|
auto key_it = pBlob->find(name);
|
|
|
|
|
|
|
|
|
|
if (key_it == pBlob->end()) {
|
|
|
|
|
(*pBlob)[name] = data; // create new blob
|
|
|
|
|
} else {
|
|
|
|
|
key_it->second = data; // set data to existing blob
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// lock will be automatically released when out of scope
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<void>>* p;
|
|
|
|
|
p = p_blobs_.get();
|
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
|
|
|
|
|
|
auto it = p->find(name);
|
|
|
|
|
int tid = platform::get_cur_thread_id();
|
|
|
|
|
|
|
|
|
|
if (it != p->end()) {
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_.get());
|
|
|
|
|
|
|
|
|
|
// Find KeyBlob for current thread firstly
|
|
|
|
|
auto map_it = pMap->find(tid);
|
|
|
|
|
if (map_it == pMap->end()) return nullptr;
|
|
|
|
|
pBlob = map_it->second;
|
|
|
|
|
|
|
|
|
|
// Find Blob via name
|
|
|
|
|
auto key_it = pBlob->find(name);
|
|
|
|
|
|
|
|
|
|
if (key_it == pBlob->end()) return nullptr;
|
|
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
// lock will be automatically released when out of scope
|
|
|
|
|
return key_it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|