MKLDNN handler cleanup (#19713)

* MKLDNN handler cleanup

* MKLDNN handler cleanup
test=develop
new_fix
Adam 6 years ago committed by Tao Luo
parent 2c30e64b2f
commit 428b2b9e17

@ -136,7 +136,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
std::vector<memory::primitive_desc> srcs_pd;
std::vector<float> scales = {1.0f, 1.0f};
const std::string key = platform::MKLDNNHandler::GetHash(
const std::string key = platform::GetHash(
src_x_tz, ctx.op().Output("Out") + std::to_string(x->format()) +
std::to_string(y->format()));

@ -72,19 +72,17 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
for (size_t i = 0; i < multi_input.size(); i++) {
platform::MKLDNNHandler::AppendKeyDims(
platform::AppendKeyDims(
&key, paddle::framework::vectorize<int>(multi_input[i]->dims()));
}
platform::MKLDNNHandler::AppendKey(&key, std::to_string(concat_axis));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Out"));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key,
std::to_string(multi_input[0]->format()));
platform::AppendKey(&key, std::to_string(concat_axis));
platform::AppendKey(&key, ctx.op().Output("Out"));
platform::AppendKey(&key, std::to_string(dt));
platform::AppendKey(&key, std::to_string(multi_input[0]->format()));
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
platform::MKLDNNHandler::AppendKey(&key, "-t:");
platform::MKLDNNHandler::AppendKey(
&key, platform::MKLDNNHandler::ThreadIDasStr());
platform::AppendKey(&key, "-t:");
platform::AppendKey(&key, platform::ThreadIDasStr());
}
return key;
}

@ -417,7 +417,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get unique name for storing MKLDNN primitives
std::string key;
key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::AppendKey(
platform::ConvMKLDNNHandler::CreateKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_activation, fuse_residual_conn,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
@ -439,7 +439,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::string key_tid = "";
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
key_tid = "-t:" + platform::MKLDNNHandler::ThreadIDasStr();
key_tid = "-t:" + platform::ThreadIDasStr();
}
auto prim_key = key + key_tid + "@conv_p";

@ -36,10 +36,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const std::vector<int>& src_tz, const float scale_data) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(src_dt));
platform::MKLDNNHandler::AppendKeyDims(&key, src_tz);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(scale_data));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Output"));
platform::AppendKey(&key, std::to_string(src_dt));
platform::AppendKeyDims(&key, src_tz);
platform::AppendKey(&key, std::to_string(scale_data));
platform::AppendKey(&key, ctx.op().Output("Output"));
return key;
}

@ -35,10 +35,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const bool is_negative) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, src_tz);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(scale_data));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(is_negative));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Output"));
platform::AppendKeyDims(&key, src_tz);
platform::AppendKey(&key, std::to_string(scale_data));
platform::AppendKey(&key, std::to_string(is_negative));
platform::AppendKey(&key, ctx.op().Output("Output"));
return key;
}

@ -205,7 +205,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Generate keys for storing/retriving primitives for this operator
const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
platform::GetHash(softmax_tz, ctx.op().Output("Out"));
SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
mkldnn_engine, key);
@ -276,7 +276,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
// Currently only supports NC data format
// retrieve eltwise primitive desc from device context
const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Input("Out"));
platform::GetHash(softmax_tz, ctx.op().Input("Out"));
const std::string key_softmax_pd = key + "@softmax_pd";
auto softmax_pd =

@ -179,5 +179,33 @@ inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
}
}
// Returns the calling thread's id, hashed via std::hash<std::thread::id>,
// rendered as a decimal string. Used to make MKLDNN primitive cache keys
// thread-unique when the default MKLDNN session id is in effect.
// Note: empty parameter list replaces the C-style "(void)" — identical
// meaning in C++, idiomatic form.
inline std::string ThreadIDasStr() {
  return std::to_string(
      std::hash<std::thread::id>()(std::this_thread::get_id()));
}
// Serializes a dims vector as "d0-d1-...-dn-" (trailing dash included),
// the textual form embedded in MKLDNN primitive cache keys.
inline std::string dims2str(const mkldnn::memory::dims& operand_dims) {
  std::string serialized;
  for (const auto& dim : operand_dims) {
    serialized += std::to_string(dim);
    serialized += "-";
  }
  return serialized;
}
// Appends the fragment s onto the cache key being accumulated in *key.
// Callers reserve capacity up front (MaxKeyLength), so this is append-only.
inline void AppendKey(std::string* key, const std::string& s) { *key += s; }
// Builds a primitive cache key from the serialized dims ("d0-d1-...-")
// followed by a caller-supplied suffix (typically op output name / format).
inline std::string GetHash(const mkldnn::memory::dims& operand_dims,
                           const std::string& suffix) {
  std::string key = dims2str(operand_dims);
  key += suffix;
  return key;
}
// Appends each dimension of dims, rendered as decimal text, onto *key
// (no separators — matches the original concatenation behavior).
inline void AppendKeyDims(std::string* key, const mkldnn::memory::dims& dims) {
  for (const auto& dim : dims) {
    AppendKey(key, std::to_string(dim));
  }
}
} // namespace platform
} // namespace paddle

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save