@@ -36,6 +36,14 @@ std::shared_ptr<T> insert_to_context(const std::string& key,
   return p;
 }
+
+template <typename... Args>
+void run_primitive(Args&&... args) {
+  auto forward_op = mkldnn::lrn_forward{args...};
+
+  std::vector<mkldnn::primitive> pipeline = {forward_op};
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+}
 }  // namespace
 
 template <typename T>
@@ -87,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
                                      static_cast<void*>(output_data)};
 
-    std::unique_ptr<mkldnn::lrn_forward> forward_op = nullptr;
-
     if (!is_test) {
       const std::string key = ctx.op().Output("Out");
       const std::string key_src_memory = key + "@lrn_src_memory";
@@ -108,9 +114,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                         key_workspace_memory, dev_ctx,
                                         forward_pd->workspace_primitive_desc());
 
-      forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory,
-                                               *workspace_memory, dst_memory});
-
+      run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
     } else {
       auto forward_pd =
           mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
@@ -119,12 +123,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       auto workspace_memory =
           mkldnn::memory{forward_pd.workspace_primitive_desc()};
 
-      forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory,
-                                               workspace_memory, dst_memory});
+      run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
     }
-
-    std::vector<mkldnn::primitive> pipeline = {*forward_op};
-    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
   }
 };
@@ -136,6 +136,9 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
                    "MKLDNN LRN must use float data.");
     PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                    "MKLDNN LRN must use CPUPlace.");
+    PADDLE_ENFORCE(
+        !ctx.Attr<bool>("is_test"),
+        "is_test attribute should be set to False in training phase.");
 
     auto x = ctx.Input<Tensor>("X");
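
Note (not part of the patch): a minimal, self-contained sketch of the variadic-helper pattern that run_primitive introduces. FakePrimitive and submit_and_wait are invented stand-ins for mkldnn::lrn_forward and the stream submit-and-wait call; only the forwarding structure mirrors the patch.

// Sketch only: stand-in types, not MKL-DNN APIs.
#include <iostream>
#include <string>
#include <vector>

struct FakePrimitive {
  std::string name;  // stands in for the real primitive's descriptors/memory
  void execute() const { std::cout << "running " << name << "\n"; }
};

// Stands in for mkldnn::stream(kind::eager).submit(pipeline).wait().
void submit_and_wait(const std::vector<FakePrimitive>& pipeline) {
  for (const auto& p : pipeline) p.execute();
}

// The pattern from the patch: construct the primitive from whatever arguments
// the call site has, wrap it in a one-element pipeline, and run it.
template <typename... Args>
void run_primitive(Args&&... args) {
  FakePrimitive op{args...};  // the patch constructs mkldnn::lrn_forward here
  std::vector<FakePrimitive> pipeline = {op};
  submit_and_wait(pipeline);
}

int main() {
  // Both former call sites (training and inference branches) reduce to a
  // single call with their branch-specific arguments.
  run_primitive(std::string{"lrn_forward (training, with workspace)"});
  run_primitive(std::string{"lrn_forward (inference)"});
  return 0;
}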