|
|
|
@ -375,36 +375,37 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
|
|
|
|
|
p_mutex_.reset(new std::mutex());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
// Current mkldnn session id.
|
|
|
|
|
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
|
|
|
|
|
// Current data input shape string.
|
|
|
|
|
// - For fixed-shape, it's a null string in default.
|
|
|
|
|
// - For dynamic-shape, it's user specific.
|
|
|
|
|
thread_local std::string cur_input_shape_str = "";
|
|
|
|
|
// the cache capacity of different input shapes for MKLDNN.
|
|
|
|
|
// Default 1 means fixed input shape, not dynamic shape.
|
|
|
|
|
thread_local int cur_input_shape_cache_capacity = 1;
|
|
|
|
|
// Recently registered data_format. This is needed to
|
|
|
|
|
// know for converting MKL-DNN Tensor to non MKL-DNN
|
|
|
|
|
thread_local paddle::framework::DataLayout cur_paddle_data_layout =
|
|
|
|
|
paddle::framework::DataLayout::kNCHW;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
|
|
|
|
|
size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
|
|
|
|
|
void set_cur_input_shape_str(std::string input_shape_str) {
|
|
|
|
|
MKLDNNDeviceContextThreadLocals::Body::Body() {
|
|
|
|
|
cur_mkldnn_session_id = kMKLDNNSessionID_Default;
|
|
|
|
|
cur_input_shape_str = "";
|
|
|
|
|
cur_input_shape_cache_capacity = 1;
|
|
|
|
|
cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
|
|
|
|
|
size_t sid) {
|
|
|
|
|
cur_mkldnn_session_id = sid;
|
|
|
|
|
}
|
|
|
|
|
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
|
|
|
|
|
return cur_mkldnn_session_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
|
|
|
|
|
std::string input_shape_str) {
|
|
|
|
|
cur_input_shape_str = input_shape_str;
|
|
|
|
|
}
|
|
|
|
|
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
|
|
|
|
|
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
|
|
|
|
|
int input_shape_cache_capacity) {
|
|
|
|
|
cur_input_shape_cache_capacity = input_shape_cache_capacity;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void set_cur_paddle_data_layout(framework::DataLayout dl) {
|
|
|
|
|
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
|
|
|
|
|
framework::DataLayout dl) {
|
|
|
|
|
cur_paddle_data_layout = dl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::DataLayout get_cur_paddle_data_layout(void) {
|
|
|
|
|
framework::DataLayout
|
|
|
|
|
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
|
|
|
|
|
return cur_paddle_data_layout;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -414,32 +415,32 @@ void MKLDNNDeviceContext::ResetBlobMap() const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
|
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
|
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
|
auto map_it = pMap->find(cur_mkldnn_session_id);
|
|
|
|
|
auto map_it = pMap->find(tls().cur_mkldnn_session_id);
|
|
|
|
|
if (map_it == pMap->end()) {
|
|
|
|
|
LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
|
|
|
|
|
<< cur_mkldnn_session_id;
|
|
|
|
|
<< tls().cur_mkldnn_session_id;
|
|
|
|
|
}
|
|
|
|
|
return map_it->second->size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNDeviceContext::SetBlob(const std::string& name,
|
|
|
|
|
std::shared_ptr<void> data) const {
|
|
|
|
|
BlobPtr_t<void> data) const {
|
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
|
std::shared_ptr<ShapeBlob> sBlob = nullptr;
|
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
|
BlobPtr_t<ShapeBlob> sBlob = nullptr;
|
|
|
|
|
BlobPtr_t<KeyBlob> pBlob = nullptr;
|
|
|
|
|
|
|
|
|
|
int sid = platform::get_cur_mkldnn_session_id();
|
|
|
|
|
int sid = tls().get_cur_mkldnn_session_id();
|
|
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
|
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
|
|
|
|
|
|
|
|
|
|
// Find ShapeBlob for current mkldnn session id.
|
|
|
|
|
auto map_it = pMap->find(sid);
|
|
|
|
|
|
|
|
|
|
if (map_it == pMap->end()) {
|
|
|
|
|
// 1st time to set blob in current thread
|
|
|
|
|
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
|
|
|
|
|
sBlob = std::make_shared<ShapeBlob>();
|
|
|
|
|
(*pMap)[sid] = sBlob;
|
|
|
|
|
VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
|
|
|
|
|
} else {
|
|
|
|
@ -447,21 +448,22 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Find KeyBlob for current input shape
|
|
|
|
|
auto key_it = sBlob->find(cur_input_shape_str);
|
|
|
|
|
auto key_it = sBlob->find(tls().cur_input_shape_str);
|
|
|
|
|
|
|
|
|
|
if (key_it == sBlob->end()) {
|
|
|
|
|
// In cache clearing mode, cur_input_shape_cache_capacity defines
|
|
|
|
|
// max pblob capacity
|
|
|
|
|
if ((static_cast<size_t>(sid) == kMKLDNNSessionID_CacheClearing) &&
|
|
|
|
|
if ((static_cast<size_t>(sid) ==
|
|
|
|
|
MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
|
|
|
|
|
sBlob->size() &&
|
|
|
|
|
(sBlob->size() >=
|
|
|
|
|
static_cast<size_t>(cur_input_shape_cache_capacity))) {
|
|
|
|
|
static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
|
|
|
|
|
VLOG(2) << "sid=" << sid
|
|
|
|
|
<< ", remove all blobs of shape: " << sBlob->begin()->first;
|
|
|
|
|
sBlob->erase(sBlob->begin()->first);
|
|
|
|
|
}
|
|
|
|
|
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
|
|
|
|
|
(*sBlob)[cur_input_shape_str] = pBlob;
|
|
|
|
|
pBlob = std::make_shared<KeyBlob>();
|
|
|
|
|
(*sBlob)[tls().cur_input_shape_str] = pBlob;
|
|
|
|
|
} else {
|
|
|
|
|
pBlob = key_it->second;
|
|
|
|
|
}
|
|
|
|
@ -478,15 +480,15 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
|
|
|
|
|
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
BlobMap* pMap = p_blobmap_.get();
|
|
|
|
|
std::shared_ptr<ShapeBlob> sBlob = nullptr;
|
|
|
|
|
std::shared_ptr<KeyBlob> pBlob = nullptr;
|
|
|
|
|
BlobPtr_t<ShapeBlob> sBlob = nullptr;
|
|
|
|
|
BlobPtr_t<KeyBlob> pBlob = nullptr;
|
|
|
|
|
|
|
|
|
|
int sid = platform::get_cur_mkldnn_session_id();
|
|
|
|
|
int sid = tls().get_cur_mkldnn_session_id();
|
|
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> lock(*p_mutex_);
|
|
|
|
|
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
|
|
|
|
|
|
|
|
|
|
// Find ShapeBlob for current mkldnn session id firstly
|
|
|
|
|
auto map_it = pMap->find(sid);
|
|
|
|
@ -497,9 +499,9 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
|
|
|
|
|
sBlob = map_it->second;
|
|
|
|
|
|
|
|
|
|
// Find KeyBlob for current input shape secondly
|
|
|
|
|
auto sBlob_it = sBlob->find(cur_input_shape_str);
|
|
|
|
|
auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
|
|
|
|
|
if (sBlob_it == sBlob->end()) {
|
|
|
|
|
VLOG(2) << "GetBlob: sid=" << cur_input_shape_str
|
|
|
|
|
VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str
|
|
|
|
|
<< ", miss input_shape_str\n";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|