|
|
|
@ -408,12 +408,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
|
|
|
|
|
std::shared_ptr<platform::ConvMKLDNNHandler> handler;
|
|
|
|
|
|
|
|
|
|
auto prim_key = key + "@conv_p";
|
|
|
|
|
auto dst_key = key + "@dst_mem_p";
|
|
|
|
|
auto src_key = key + "@src_mem_p";
|
|
|
|
|
auto user_src_key = key + "@user_src_mem_p";
|
|
|
|
|
auto src_reorder_key = key + "@src_mem_preorder_p";
|
|
|
|
|
auto residual_reorder_key = key + "@residual_data_mem_preorder_p";
|
|
|
|
|
// This is workaround for hacky implementation
|
|
|
|
|
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
|
|
|
|
|
// are merged/unified, this will disappear
|
|
|
|
|
std::string key_tid = "";
|
|
|
|
|
if (platform::get_cur_mkldnn_session_id() ==
|
|
|
|
|
platform::kMKLDNNSessionID_Default) {
|
|
|
|
|
key_tid = "-t:" + platform::MKLDNNHandler::ThreadIDasStr();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto prim_key = key + key_tid + "@conv_p";
|
|
|
|
|
auto dst_key = key + key_tid + "@dst_mem_p";
|
|
|
|
|
auto src_key = key + key_tid + "@src_mem_p";
|
|
|
|
|
auto user_src_key = key + key_tid + "@user_src_mem_p";
|
|
|
|
|
auto src_reorder_key = key + key_tid + "@src_mem_preorder_p";
|
|
|
|
|
auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p";
|
|
|
|
|
|
|
|
|
|
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
|
|
|
|
|
dev_ctx.GetBlob(prim_key));
|
|
|
|
|