@@ -67,7 +67,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     mid->mutable_data<T>(ctx.GetPlace());
 
     const int n = ctx.Attr<int>("n");
-    const float alpha = ctx.Attr<float>("alpha");
+    // MKL-DNN implements LRN in a caffe way:
+    // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
+    // where the sum of squares is divided by the size of the normalization
+    // window; this is not the case for PaddlePaddle LRN.
+    // Hence we need to compensate for this difference by
+    // multiplying alpha by the size of the window (n).
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
     const bool is_test = ctx.Attr<bool>("is_test");
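For reference, a minimal standalone sketch of the two conventions the comment contrasts; the function names below are illustrative only and not part of the patch:

```cpp
#include <cmath>

// Caffe/MKL-DNN convention: the squared sum over the window is averaged,
// i.e. it enters the formula scaled by alpha / n.
inline float lrn_caffe_style(float x, float sq_sum, int n, float alpha,
                             float beta, float k) {
  return x / std::pow(k + alpha / n * sq_sum, beta);
}

// PaddlePaddle convention: the squared sum enters with alpha unscaled.
inline float lrn_paddle_style(float x, float sq_sum, float alpha, float beta,
                              float k) {
  return x / std::pow(k + alpha * sq_sum, beta);
}
```

Feeding `alpha * n` into the caffe-style formula gives `(alpha * n) / n == alpha`, i.e. exactly the PaddlePaddle formula, which is why the attribute is scaled before being handed to MKL-DNN.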
@@ -78,10 +84,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto dims = paddle::framework::vectorize2int(x->dims());
 
     auto src_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
-
-    auto dst_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+        dims, mkldnn::memory::data_type::f32, x->format());
 
     auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
                                                   mkldnn::lrn_across_channels,
@@ -92,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                                   k};
 
     auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
-    auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
-                                     static_cast<void*>(output_data)};
 
     if (!is_test) {
       const std::string key = ctx.op().Output("Out");
@@ -110,11 +111,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       src_memory->set_data_handle(
           static_cast<void*>(const_cast<T*>(input_data)));
 
+      auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
+                                       static_cast<void*>(output_data));
       auto workspace_memory = insert_to_context<mkldnn::memory>(
           key_workspace_memory, dev_ctx,
           forward_pd->workspace_primitive_desc());
 
       run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
+
+      out->set_layout(framework::DataLayout::kMKLDNN);
+      out->set_format(platform::GetMKLDNNFormat(dst_memory));
     } else {
       auto forward_pd =
           mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
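The training branch above caches the LRN workspace in the device context under a key derived from the output name (`key + "@lrn_workspace_memory"`), so the backward kernel can retrieve the same memory later. A minimal sketch of what an `insert_to_context`-style helper can look like, assuming `MKLDNNDeviceContext` exposes blob storage via `GetBlob`/`SetBlob`; the helper actually used by this file may differ in detail:

```cpp
#include <memory>
#include <string>
#include <utility>

#include "paddle/fluid/platform/device_context.h"

// Create-or-fetch: returns the object cached in the device context under
// `key`, constructing and storing it on the first (training) forward pass.
// Non-const reference for simplicity; the real context may be passed const.
template <typename T, typename... Args>
std::shared_ptr<T> insert_to_context(
    const std::string& key, paddle::platform::MKLDNNDeviceContext& dev_ctx,
    Args&&... args) {
  auto blob = dev_ctx.GetBlob(key);
  if (blob == nullptr) {
    auto value = std::make_shared<T>(std::forward<Args>(args)...);
    dev_ctx.SetBlob(key, std::static_pointer_cast<void>(value));
    return value;
  }
  // Subsequent passes reuse the cached object.
  return std::static_pointer_cast<T>(blob);
}
```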
@@ -122,8 +128,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
       auto workspace_memory =
           mkldnn::memory{forward_pd.workspace_primitive_desc()};
+      auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
+                                       static_cast<void*>(output_data));
 
       run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
+
+      out->set_layout(framework::DataLayout::kMKLDNN);
+      out->set_format(platform::GetMKLDNNFormat(dst_memory));
     }
   }
 };
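Design note on the pattern shared by both branches: the destination memory is now created from the primitive descriptor's `dst_primitive_desc()` instead of a hand-built `nchw` descriptor (hence the removal of `dst_md` earlier in the patch), which leaves MKL-DNN free to choose an efficient, possibly blocked, output layout. Recording that choice on the output tensor via `set_layout(kMKLDNN)` and `set_format(...)` lets the next MKL-DNN operator consume the data directly rather than reordering it back to `nchw`; the same reasoning motivates building `src_md` from `x->format()` instead of assuming `nchw`.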
@@ -151,7 +162,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key_workspace_memory = key + "@lrn_workspace_memory";
 
     const int n = ctx.Attr<int>("n");
-    const float alpha = ctx.Attr<float>("alpha");
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
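The backward kernel receives the same window-size compensation as the forward kernel, so the gradient is computed with the `alpha` that MKL-DNN actually used during the forward pass.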