|
|
|
@ -22,6 +22,22 @@ namespace operators {
|
|
|
|
|
using paddle::framework::Tensor;
|
|
|
|
|
using paddle::platform::MKLDNNDeviceContext;
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
template <typename T, typename... Args>
|
|
|
|
|
std::shared_ptr<T> insert_to_context(const std::string& key,
|
|
|
|
|
const MKLDNNDeviceContext& dev_ctx,
|
|
|
|
|
Args&&... args) {
|
|
|
|
|
auto p = std::static_pointer_cast<T, void>(dev_ctx.GetBlob(key));
|
|
|
|
|
|
|
|
|
|
if (!p) {
|
|
|
|
|
p = std::make_shared<T>(args...);
|
|
|
|
|
dev_ctx.SetBlob(key, std::static_pointer_cast<void, T>(p));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return p;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -42,15 +58,11 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto output_data = out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
mid->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
const std::string key = ctx.op().Output("Out");
|
|
|
|
|
const std::string key_src_memory = key + "@lrn_src_memory";
|
|
|
|
|
const std::string key_pd = key + "@lrn_pd";
|
|
|
|
|
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
|
|
|
|
|
|
|
|
|
|
const int n = ctx.Attr<int>("n");
|
|
|
|
|
const float alpha = ctx.Attr<float>("alpha");
|
|
|
|
|
const float beta = ctx.Attr<float>("beta");
|
|
|
|
|
const float k = ctx.Attr<float>("k");
|
|
|
|
|
const bool is_test = ctx.Attr<bool>("is_test");
|
|
|
|
|
|
|
|
|
|
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
|
|
|
|
|
e_mid = e_mid.constant(k);
|
|
|
|
@ -71,28 +83,47 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
beta,
|
|
|
|
|
k};
|
|
|
|
|
|
|
|
|
|
auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
|
|
|
|
|
forward_desc, mkldnn_engine);
|
|
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_pd, forward_pd);
|
|
|
|
|
|
|
|
|
|
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
|
|
|
|
|
auto src_memory = std::make_shared<mkldnn::memory>(
|
|
|
|
|
src_memory_pd, static_cast<void*>(const_cast<float*>(input_data)));
|
|
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_src_memory, src_memory);
|
|
|
|
|
auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
|
|
|
|
|
static_cast<void*>(output_data)};
|
|
|
|
|
|
|
|
|
|
auto workspace_md = forward_pd->workspace_primitive_desc();
|
|
|
|
|
auto workspace_memory = std::make_shared<mkldnn::memory>(workspace_md);
|
|
|
|
|
std::unique_ptr<mkldnn::lrn_forward> forward_op = nullptr;
|
|
|
|
|
|
|
|
|
|
if (!is_test) {
|
|
|
|
|
const std::string key = ctx.op().Output("Out");
|
|
|
|
|
const std::string key_src_memory = key + "@lrn_src_memory";
|
|
|
|
|
const std::string key_pd = key + "@lrn_pd";
|
|
|
|
|
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
|
|
|
|
|
|
|
|
|
|
auto forward_pd = insert_to_context<mkldnn::lrn_forward::primitive_desc>(
|
|
|
|
|
key_pd, dev_ctx, forward_desc, mkldnn_engine);
|
|
|
|
|
|
|
|
|
|
auto src_memory = insert_to_context<mkldnn::memory>(
|
|
|
|
|
key_src_memory, dev_ctx, src_memory_pd);
|
|
|
|
|
|
|
|
|
|
src_memory->set_data_handle(
|
|
|
|
|
static_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
|
|
|
|
|
|
auto workspace_memory = insert_to_context<mkldnn::memory>(
|
|
|
|
|
key_workspace_memory, dev_ctx,
|
|
|
|
|
forward_pd->workspace_primitive_desc());
|
|
|
|
|
|
|
|
|
|
forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory,
|
|
|
|
|
*workspace_memory, dst_memory});
|
|
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_workspace_memory, workspace_memory);
|
|
|
|
|
} else {
|
|
|
|
|
auto forward_pd =
|
|
|
|
|
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
|
|
|
|
|
auto src_memory = mkldnn::memory{
|
|
|
|
|
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
|
|
|
|
|
auto workspace_memory =
|
|
|
|
|
mkldnn::memory{forward_pd.workspace_primitive_desc()};
|
|
|
|
|
|
|
|
|
|
auto forward_op = mkldnn::lrn_forward{*forward_pd, *src_memory,
|
|
|
|
|
*workspace_memory, dst_memory};
|
|
|
|
|
forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory,
|
|
|
|
|
workspace_memory, dst_memory});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<mkldnn::primitive> pipeline = {forward_op};
|
|
|
|
|
std::vector<mkldnn::primitive> pipeline = {*forward_op};
|
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|