diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 30464bbca9..a42d291318 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -181,8 +181,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, if (in_format != out_format) { void* in_data = GetDataFromTensor(in, in_type); - const std::string key = - platform::CreateKey(in_tz, in_format, out_format, in_type); + std::string key = + platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type); platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx, cpu_engine, key); diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc index e51d94e4b1..1eed49de78 100644 --- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc @@ -39,20 +39,15 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - CreateKey(unique_name, MKLDNNGetDataType(), Ti)), + CreateKey(dev_ctx, unique_name, MKLDNNGetDataType(), Ti)), N(N), Ti(Ti), IC(IC), OC(OC) { // Create memory key without Ti because weights, bias and h0 memories // do not depend on Ti size but primitive and input/output memory do - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType()); - } else { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType(), "-t:", - platform::ThreadIDasStr()); - } + memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType())); // Is it int8 kernel const bool is_INT8 = std::is_same::value; diff --git a/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc index b7fd40f78f..11711bab81 100644 --- a/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc @@ -109,13 +109,8 @@ class MultiGRUHandler { const std::string unique_name = ctx.OutputName("Hidden"); // Create memory key without Ti because weights, bias and h0 memories // do not depend on Ti size but primitive and input/output memory do - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType()); - } else { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType(), "-t:", - platform::ThreadIDasStr()); - } + memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType())); key_ = memory_key_; key_.append("T").append(std::to_string(Ti_)); diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc index 98f368aa7a..622d6685df 100644 --- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc @@ -48,7 +48,8 @@ class BatchNormMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(framework::vectorize(x->dims()), unique_name)) { + platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), + unique_name)) { if (!this->isCached()) { const float epsilon = ctx.Attr("epsilon"); const bool fuse_with_relu = ctx.Attr("fuse_with_relu"); @@ -89,7 +90,7 @@ class BatchNormMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, uniq_name)) { + platform::CreateKey(dev_ctx, dims, uniq_name)) { auto diff_dst_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), diff_fmt); auto src_md = diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 114daaecb5..ea9b629d90 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -158,9 +158,10 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { // If one of the multiple inputs of concat has an input size of 0, the // actual size of the multi_input will change std::string key = platform::CreateKey( - paddle::framework::vectorize(multi_input[0]->dims()), + dev_ctx, paddle::framework::vectorize(multi_input[0]->dims()), multi_input.size(), ctx.OutputName("Out"), dt, - platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); + platform::ThreadIDasStr()); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); const std::string key_prim = key + "@concat_p"; const std::string key_concat_pd = key + "@concat_pd"; diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 2e6d809c98..68fe582838 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -95,7 +95,7 @@ class ConvMKLDNNHandlerT const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, - platform::CreateKey(framework::vectorize(input->dims()), + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), unique_name)) { if (!this->isCached()) { PADDLE_ENFORCE_EQ( @@ -521,8 +521,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); - std::string key = platform::CreateKey( - src_tz, src_dt, ctx.InputName("Input") + ctx.InputName("Filter")); + std::string key = + platform::CreateKey(dev_ctx, src_tz, src_dt, + ctx.InputName("Input") + ctx.InputName("Filter")); const std::string key_conv_pd = key + "@conv_pd"; bool need_s8_to_u8 = false; @@ -537,21 +538,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // This is workaround for hacky implementation // of conv int8 mkl-dnn. Once conv fp32 and conv int8 // are merged/unified, this will disappear - std::string key_tid = ""; - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() == - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - key_tid = "-t:" + platform::ThreadIDasStr(); - } - - auto prim_key = key + key_tid + "@conv_p"; - auto dst_key = key + key_tid + "@dst_mem_p"; - auto src_key = key + key_tid + "@src_mem_p"; - auto weights_key = key + key_tid + "@weights_mem_p"; - auto bias_key = key + key_tid + "@bias_mem_p"; - auto user_src_key = key + key_tid + "@user_src_mem_p"; - auto user_residual_key = key + key_tid + "@user_residual_data_mem_p"; - auto src_reorder_key = key + key_tid + "@src_mem_preorder_p"; - auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p"; + auto key_tid = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + + auto prim_key = key_tid + "@conv_p"; + auto dst_key = key_tid + "@dst_mem_p"; + auto src_key = key_tid + "@src_mem_p"; + auto weights_key = key_tid + "@weights_mem_p"; + auto bias_key = key_tid + "@bias_mem_p"; + auto user_src_key = key_tid + "@user_src_mem_p"; + auto user_residual_key = key_tid + "@user_residual_data_mem_p"; + auto src_reorder_key = key_tid + "@src_mem_preorder_p"; + auto residual_reorder_key = key_tid + "@residual_data_mem_preorder_p"; conv_p = std::static_pointer_cast( dev_ctx.GetBlob(prim_key)); @@ -972,10 +969,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { // Get an unique name from "argument" name of "input" and "Filter" variable // as well as attributes of primitive to be created // This name will be used as key when saving info into device context - const std::string key = platform::CreateKey( - src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); + std::string key = platform::CreateKey( + dev_ctx, src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); const std::string key_conv_pd = key + "@fwd_pd"; + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); std::vector pipeline; // Create user memory descriptors @@ -1090,8 +1088,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { mkldnn::memory::format_tag out_format = weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw : mkldnn::memory::format_tag::goihw; - const std::string key = - platform::CreateKey(weights_tz, filter_fmt, out_format, in_type); + std::string key = platform::CreateKey(dev_ctx, weights_tz, filter_fmt, + out_format, in_type); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); platform::ReorderMKLDNNHandler handler(weights_tz, filter_grad->type(), in_type, dev_ctx, mkldnn_engine, diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index e9f32e7ac2..1eb90451a6 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -172,9 +172,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_tz = paddle::framework::vectorize(output->dims()); // Get unique name for storing MKLDNN primitives - const std::string key = - platform::CreateKey(src_tz, ctx.OutputName("Output")); + platform::CreateKey(dev_ctx, src_tz, ctx.OutputName("Output")); std::vector pipeline; diff --git a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc index e036fd9aba..8d41b75097 100644 --- a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc @@ -67,8 +67,11 @@ class DeQuantOpKernel : public framework::OpKernel { mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); MKLDNNMemoryFormat src_fmt = input->format(); - std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_dt, - src_tz, ctx.OutputName("Output")); + + std::string key = + platform::CreateKey(dev_ctx, src_dt, src_tz, ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string key_prim = key + "@r"; const std::string key_src_mem = key + "@s"; const std::string key_dst_mem = key + "@d"; diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 820c46c67d..c0cfbd089f 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -370,8 +370,9 @@ class FCPrimitiveFactory { void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx) { - const std::string key = - platform::CreateKey(platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); + std::string key = platform::CreateKey(dev_ctx); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string weights_key = key + ctx.InputName("W"); const std::string bias_key = key + ctx.InputName("Bias"); dev_ctx.SetBlob(weights_key, weights_); @@ -541,10 +542,11 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input, const Tensor* w, const Tensor* bias, LoDTensor* output, bool fuse_relu, bool force_fp32_output) { auto& dev_ctx = ctx.template device_context(); - const std::string prim_key = platform::CreateKey( - platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), input->format(), - input->dims()[0], framework::vectorize(w->dims()), - ctx.OutputName("Out")); + std::string prim_key = platform::CreateKey( + dev_ctx, input->format(), input->dims()[0], + framework::vectorize(w->dims()), ctx.OutputName("Out")); + prim_key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, prim_key); + constexpr bool is_int8 = std::is_same::value || std::is_same::value; bool is_bfloat16 = std::is_same::value; diff --git a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc index 22261e948a..65dcb328f2 100644 --- a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc @@ -30,7 +30,7 @@ class LayerNormMKLDNNHandler const std::string& uniq_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, uniq_name)) { + platform::CreateKey(dev_ctx, dims, uniq_name)) { if (!this->isCached()) { auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); if (!is_test) { diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 1f2216cbed..92be4d19e7 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -336,9 +336,8 @@ static std::shared_ptr> GetPrimitiveFactory( const auto& out_name = ctx.OutputName("Out"); const auto& dev_ctx = ctx.template device_context(); const auto batch_size = ctx.Input("X")->dims()[0]; - - const std::string key = platform::CreateKey( - platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), batch_size, out_name); + std::string key = platform::CreateKey(dev_ctx, batch_size, out_name); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); auto factory = std::static_pointer_cast>(dev_ctx.GetBlob(key)); diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index 258b6971a0..4174d88de6 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -305,9 +305,11 @@ std::shared_ptr> GetPrimitiveFactory( const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx, const Tensor *input_x, const Tensor *input_y, const mkldnn::engine &mkldnn_engine) { - const std::string key = platform::CreateKey( - input_x->type(), framework::vectorize(input_x->dims()), input_y->type(), - framework::vectorize(input_y->dims()), ctx.OutputName("Out")); + std::string key = platform::CreateKey( + dev_ctx, input_x->type(), framework::vectorize(input_x->dims()), + input_y->type(), framework::vectorize(input_y->dims()), + ctx.OutputName("Out")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); auto prim_creator = std::static_pointer_cast>( dev_ctx.GetBlob(key)); diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc index 4e689f5bcc..9488a1a440 100644 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc @@ -140,7 +140,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { // Get an unique name from "argument" name of "Out" variable // This name will be used as key when referring info from device context const std::string key = platform::CreateKey( - diff_src_tz, pooling_type, ksize, strides, paddings, + dev_ctx, diff_src_tz, pooling_type, ksize, strides, paddings, memory::data_type::f32, in_x->format(), ctx.InputName("Out")); platform::PoolingMKLDNNHandler handler( diff --git a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc index 3e04e2dcf0..7a03c6ce86 100644 --- a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc @@ -64,9 +64,11 @@ class QuantOpKernel : public framework::OpKernel { bool is_negative_input = ctx.Attr("is_negative_input"); bool bfloat16 = ctx.Attr("bfloat16"); - std::string key = platform::CreateKey( - platform::ThreadIDasStr(), src_tz, scale_data, scale_shift, - is_negative_input, ctx.OutputName("Output")); + std::string key = + platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift, + is_negative_input, ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string key_prim = key + "@r"; const std::string key_src_mem = key + "@s"; const std::string key_dst_mem = key + "@d"; diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc index a3b078205e..aa74a45e3a 100644 --- a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc @@ -65,9 +65,9 @@ class ReQuantOpKernel : public framework::OpKernel { float reorder_scale = scale_out / scale_in; - std::string key = - platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in, - scale_out, ctx.OutputName("Output")); + std::string key = platform::CreateKey(dev_ctx, src_tz, scale_in, scale_out, + ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); const std::string key_prim = key + "@r"; const std::string key_src_mem = key + "@s"; const std::string key_dst_mem = key + "@d"; diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index 9d9e1e2d8d..3eb2e7084a 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -53,8 +53,8 @@ class SoftmaxMKLDNNHandler mkldnn::softmax_backward>( dev_ctx, mkldnn_engine, cpu_place, // Softmax may be inplace then uniq_name is no longer unique - platform::CreateKey(framework::vectorize(input->dims()), axis, - uniq_name)) { + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), + axis, uniq_name)) { if (!this->isCached()) { PADDLE_ENFORCE_EQ( input->dims(), output->dims(), @@ -78,7 +78,7 @@ class SoftmaxMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, axis, uniq_name)) { + platform::CreateKey(dev_ctx, dims, axis, uniq_name)) { auto data_softmax_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); auto diff_softmax_md = diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc index e1031c02be..2b6f959472 100644 --- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc @@ -54,7 +54,8 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT { : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(framework::vectorize(z->dims()), uniq_name)), + platform::CreateKey(dev_ctx, framework::vectorize(z->dims()), + uniq_name)), num_inputs_(0) { for (size_t i = 0; i < in_vars.size(); i++) { srcs_suffix_.push_back(std::string("-") + std::to_string(i)); @@ -184,8 +185,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { // For in-place execution which sum does not have we need to fake it // so from oneDNN dst memory we reorder data into input if (in_place) { - const std::string reorder_key = platform::CreateKey( - framework::vectorize(output->dims()), ctx.OutputName("Out") + "-I"); + const std::string reorder_key = + platform::CreateKey(dev_ctx, framework::vectorize(output->dims()), + ctx.OutputName("Out") + "-I"); auto& in_out = in_vars[0]->Get(); auto output_tz = framework::vectorize(output->dims()); diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index 28cdd8413a..feda5645b4 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -48,7 +48,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { auto nchw_tz = paddle::framework::vectorize(input->dims()); - const std::string key = platform::CreateKey(nchw_tz, ctx.OutputName("Out")); + const std::string key = + platform::CreateKey(dev_ctx, nchw_tz, ctx.OutputName("Out")); platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx, mkldnn_engine, key); @@ -103,7 +104,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto nchw_tz = paddle::framework::vectorize(out_grad->dims()); const std::string key = platform::CreateKey( - nchw_tz, ctx.OutputName(framework::GradVarName("X"))); + dev_ctx, nchw_tz, ctx.OutputName(framework::GradVarName("X"))); platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx, mkldnn_engine, key); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 620e2d41c1..8661c5e2ce 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -532,6 +532,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext { void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; } const std::string& GetKeySuffix(void) const { return key_suffix_; } + // Disable adding thread ID to the key + void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; }; + bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; }; + // Prevent next ResetBlobMap() void BlockNextCacheClearing(); @@ -554,6 +558,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { std::shared_ptr p_mutex_; bool block_next_cache_clearing_ = false; std::string key_suffix_; // Key identifying current Executor + bool key_attach_thread_id_ = true; }; #endif diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 99044c53d2..2de08773df 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -431,11 +431,6 @@ inline void AppendKey(std::string* key, const std::vector& dims) { } } -inline unsigned int HashPointer(uintptr_t ptr) { - // Get four less meaningful digits in decimal numerals - return ptr % 1000; -} - // If MKLDNN build and CPU place then register suffix in DeviceContext inline void AttachPointerHashToMKLDNNKey(void* ptr, const platform::Place& place) { @@ -443,20 +438,34 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr, platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::MKLDNNDeviceContext* dev_ctx = (platform::MKLDNNDeviceContext*)pool.Get(place); - dev_ctx->SetKeySuffix("E" + std::to_string(platform::HashPointer( - reinterpret_cast(ptr)))); + dev_ctx->SetKeySuffix("E" + + std::to_string(reinterpret_cast(ptr))); + // When NaiveExecutor/Executor is used no info on thread id is needed in a + // key + dev_ctx->DisableThreadInfoInKey(); } } template -inline std::string CreateKey(ArgTypes&&... args) { +inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx, + ArgTypes&&... args) { std::string key; key.reserve(64); using expand_type = int[]; expand_type{0, (AppendKey(&key, std::forward(args)), 0)...}; + key += dev_ctx.GetKeySuffix(); return key; } +inline std::string ExtendKeyWithThreadInfoIfNeeded( + const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) { + return ((dev_ctx.IsThreadIdUsedInKey() == true) && + (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() == + platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default)) + ? key + "-t:" + ThreadIDasStr() + : key; +} + inline std::vector> ToMkldnnPadding( const std::vector& paddings) { if (paddings.size() == 6) { diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 6976e55b23..03443996b6 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -43,16 +43,9 @@ class MKLDNNHandlerT { engine_(engine), place_(cpu_place), key_common_(base_key), + key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)), fwd_pd_(nullptr), - bwd_pd_(nullptr) { - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - key_ = key_common_; - } else { - key_ = key_common_ + "-t:" + ThreadIDasStr(); - } - key_ += dev_ctx.GetKeySuffix(); - } + bwd_pd_(nullptr) {} std::shared_ptr AcquireForwardPrimitive() { const std::string key_p = key_ + "@fwd_p"; @@ -306,8 +299,8 @@ class MKLDNNHandlerT { const MKLDNNDeviceContext& dev_ctx_; mkldnn::engine engine_; platform::Place place_; - std::string key_; std::string key_common_; + std::string key_; std::shared_ptr fwd_pd_; std::shared_ptr bwd_pd_; }; @@ -317,15 +310,10 @@ class MKLDNNHandler { public: MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, const std::string& base_key) - : dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) { - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - key_ = key_common_; - } else { - key_ = key_common_ + "-t:" + ThreadIDasStr(); - } - key_ += dev_ctx.GetKeySuffix(); - } + : dev_ctx_(dev_ctx), + engine_(engine), + key_common_(base_key), + key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) {} std::shared_ptr AcquireSrcMemory( const mkldnn::memory::desc& md, void* ptr) { @@ -508,8 +496,8 @@ class MKLDNNHandler { protected: const MKLDNNDeviceContext& dev_ctx_; mkldnn::engine engine_; - std::string key_; std::string key_common_; + std::string key_; }; template @@ -524,7 +512,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey( - framework::vectorize(x->dims()), + dev_ctx, framework::vectorize(x->dims()), uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { // bradcasting combined with in-place may require auto rankdiff = x->dims().size() - y->dims().size(); @@ -627,7 +615,7 @@ class ActivationMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, "a", algorithm, unique_name)) { + platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) { auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, @@ -645,7 +633,7 @@ class ActivationMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, "a", algorithm, unique_name)) { + platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) { auto diff_dst_md = platform::MKLDNNMemDesc( dims, platform::MKLDNNGetDataType(), diff_fmt); auto src_md = @@ -676,7 +664,7 @@ class LRNMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, - platform::CreateKey(framework::vectorize(input->dims()), + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), unique_name)) { if (!this->isCached()) { const int n = ctx.Attr("n"); @@ -712,7 +700,7 @@ class LRNMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, unique_name)) { + platform::CreateKey(dev_ctx, dims, unique_name)) { auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); auto diff_md = @@ -752,7 +740,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(framework::vectorize(input->dims()), + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), framework::ToMKLDNNDataType(input->type()), unique_name)) { if (!this->isCached()) { @@ -861,7 +849,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(diff_src_dims, dt, unique_name)) { + platform::CreateKey(dev_ctx, diff_src_dims, dt, unique_name)) { auto diff_dst_md = mkldnn::memory::desc( diff_dst_dims, platform::MKLDNNGetDataType(), diff_dst_fmt); auto diff_src_md =