|
|
|
|
@ -78,10 +78,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto dims = paddle::framework::vectorize2int(x->dims());
|
|
|
|
|
|
|
|
|
|
auto src_md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
auto dst_md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
|
|
|
|
|
dims, mkldnn::memory::data_type::f32, x->format());
|
|
|
|
|
|
|
|
|
|
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
|
|
|
|
|
mkldnn::lrn_across_channels,
|
|
|
|
|
@ -92,8 +89,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
k};
|
|
|
|
|
|
|
|
|
|
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
|
|
|
|
|
auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
|
|
|
|
|
static_cast<void*>(output_data)};
|
|
|
|
|
|
|
|
|
|
if (!is_test) {
|
|
|
|
|
const std::string key = ctx.op().Output("Out");
|
|
|
|
|
@ -110,11 +105,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
src_memory->set_data_handle(
|
|
|
|
|
static_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
|
|
|
|
|
|
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
|
|
|
|
|
static_cast<void*>(output_data));
|
|
|
|
|
auto workspace_memory = insert_to_context<mkldnn::memory>(
|
|
|
|
|
key_workspace_memory, dev_ctx,
|
|
|
|
|
forward_pd->workspace_primitive_desc());
|
|
|
|
|
|
|
|
|
|
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
|
|
|
|
|
|
|
|
|
|
out->set_layout(framework::DataLayout::kMKLDNN);
|
|
|
|
|
out->set_format(platform::GetMKLDNNFormat(dst_memory));
|
|
|
|
|
} else {
|
|
|
|
|
auto forward_pd =
|
|
|
|
|
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
|
|
|
|
|
@ -122,8 +122,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
|
|
|
|
|
auto workspace_memory =
|
|
|
|
|
mkldnn::memory{forward_pd.workspace_primitive_desc()};
|
|
|
|
|
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
|
|
|
|
|
static_cast<void*>(output_data));
|
|
|
|
|
|
|
|
|
|
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
|
|
|
|
|
|
|
|
|
|
out->set_layout(framework::DataLayout::kMKLDNN);
|
|
|
|
|
out->set_format(platform::GetMKLDNNFormat(dst_memory));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|