|
|
|
@ -14,7 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/tensor.h"
|
|
|
|
|
#include "paddle/fluid/operators/lrn_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_reuse.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -22,30 +22,6 @@ namespace operators {
|
|
|
|
|
using paddle::framework::Tensor;
|
|
|
|
|
using paddle::platform::MKLDNNDeviceContext;
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
template <typename T, typename... Args>
|
|
|
|
|
std::shared_ptr<T> insert_to_context(const std::string& key,
|
|
|
|
|
const MKLDNNDeviceContext& dev_ctx,
|
|
|
|
|
Args&&... args) {
|
|
|
|
|
auto p = std::static_pointer_cast<T, void>(dev_ctx.GetBlob(key));
|
|
|
|
|
|
|
|
|
|
if (!p) {
|
|
|
|
|
p = std::make_shared<T>(args...);
|
|
|
|
|
dev_ctx.SetBlob(key, std::static_pointer_cast<void, T>(p));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... Args>
|
|
|
|
|
void run_primitive(Args&&... args) {
|
|
|
|
|
auto forward_op = mkldnn::lrn_forward{args...};
|
|
|
|
|
|
|
|
|
|
std::vector<mkldnn::primitive> pipeline = {forward_op};
|
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -76,66 +52,42 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
|
|
|
|
|
const float beta = ctx.Attr<float>("beta");
|
|
|
|
|
const float k = ctx.Attr<float>("k");
|
|
|
|
|
const bool is_test = ctx.Attr<bool>("is_test");
|
|
|
|
|
|
|
|
|
|
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
|
|
|
|
|
e_mid = e_mid.constant(k);
|
|
|
|
|
|
|
|
|
|
auto dims = paddle::framework::vectorize2int(x->dims());
|
|
|
|
|
|
|
|
|
|
auto src_md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, mkldnn::memory::data_type::f32, x->format());
|
|
|
|
|
|
|
|
|
|
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
|
|
|
|
|
mkldnn::lrn_across_channels,
|
|
|
|
|
src_md,
|
|
|
|
|
n,
|
|
|
|
|
alpha,
|
|
|
|
|
beta,
|
|
|
|
|
k};
|
|
|
|
|
|
|
|
|
|
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
|
|
|
|
|
|
|
|
|
|
if (!is_test) {
|
|
|
|
|
const std::string key = ctx.op().Output("Out");
|
|
|
|
|
const std::string key_src_memory = key + "@lrn_src_memory";
|
|
|
|
|
const std::string key_pd = key + "@lrn_pd";
|
|
|
|
|
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
|
|
|
|
|
|
|
|
|
|
auto forward_pd = insert_to_context<mkldnn::lrn_forward::primitive_desc>(
|
|
|
|
|
key_pd, dev_ctx, forward_desc, mkldnn_engine);
|
|
|
|
|
|
|
|
|
|
auto src_memory = insert_to_context<mkldnn::memory>(
|
|
|
|
|
key_src_memory, dev_ctx, src_memory_pd);
|
|
|
|
|
|
|
|
|
|
src_memory->set_data_handle(
|
|
|
|
|
static_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
|
|
|
|
|
|
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
|
|
|
|
|
static_cast<void*>(output_data));
|
|
|
|
|
auto workspace_memory = insert_to_context<mkldnn::memory>(
|
|
|
|
|
key_workspace_memory, dev_ctx,
|
|
|
|
|
forward_pd->workspace_primitive_desc());
|
|
|
|
|
|
|
|
|
|
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
|
|
|
|
|
|
|
|
|
|
out->set_layout(framework::DataLayout::kMKLDNN);
|
|
|
|
|
out->set_format(platform::GetMKLDNNFormat(dst_memory));
|
|
|
|
|
} else {
|
|
|
|
|
auto forward_pd =
|
|
|
|
|
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
|
|
|
|
|
auto src_memory = mkldnn::memory{
|
|
|
|
|
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
|
|
|
|
|
auto workspace_memory =
|
|
|
|
|
mkldnn::memory{forward_pd.workspace_primitive_desc()};
|
|
|
|
|
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
|
|
|
|
|
static_cast<void*>(output_data));
|
|
|
|
|
|
|
|
|
|
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
|
|
|
|
|
|
|
|
|
|
out->set_layout(framework::DataLayout::kMKLDNN);
|
|
|
|
|
out->set_format(platform::GetMKLDNNFormat(dst_memory));
|
|
|
|
|
}
|
|
|
|
|
// Format and dims are assumed to be the same for dst and src
|
|
|
|
|
auto md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, platform::MKLDNNGetDataType<T>(), x->format());
|
|
|
|
|
|
|
|
|
|
const std::string key = platform::LRNMKLDNNHandler::GetHash(
|
|
|
|
|
dims, n, alpha, beta, k, x->format(), ctx.op().Output("Out"));
|
|
|
|
|
|
|
|
|
|
platform::LRNMKLDNNHandler handler(ctx.Attr<bool>("is_test"), dev_ctx,
|
|
|
|
|
mkldnn_engine, key);
|
|
|
|
|
auto src_memory =
|
|
|
|
|
handler.AcquireSrcMemory(md, platform::to_void_cast<T>(input_data));
|
|
|
|
|
|
|
|
|
|
// TODO(jczaja): Hide getting PD inside of handler for all Acquire API
|
|
|
|
|
handler.AcquireLRNPrimitiveDescriptor(md, n, alpha, beta, k);
|
|
|
|
|
|
|
|
|
|
auto dst_memory =
|
|
|
|
|
handler.AcquireDstMemory(md, platform::to_void_cast<T>(output_data));
|
|
|
|
|
|
|
|
|
|
auto lrn_p = handler.AcquireLRN(dst_memory, src_memory);
|
|
|
|
|
|
|
|
|
|
std::vector<mkldnn::primitive> pipeline = {*lrn_p};
|
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
|
|
|
|
|
auto output_format =
|
|
|
|
|
(mkldnn::memory::format)dst_memory->get_primitive_desc()
|
|
|
|
|
.desc()
|
|
|
|
|
.data.format;
|
|
|
|
|
|
|
|
|
|
out->set_layout(framework::DataLayout::kMKLDNN);
|
|
|
|
|
out->set_format(output_format);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -156,11 +108,6 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
|
|
|
|
|
const std::string key = ctx.op().Input("Out");
|
|
|
|
|
const std::string key_src_memory = key + "@lrn_src_memory";
|
|
|
|
|
const std::string key_pd = key + "@lrn_pd";
|
|
|
|
|
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
|
|
|
|
|
|
|
|
|
|
const int n = ctx.Attr<int>("n");
|
|
|
|
|
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
|
|
|
|
|
const float beta = ctx.Attr<float>("beta");
|
|
|
|
@ -174,42 +121,46 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto dims = paddle::framework::vectorize2int(x->dims());
|
|
|
|
|
|
|
|
|
|
auto src_md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
|
|
|
|
|
const std::string key = platform::LRNMKLDNNHandler::GetHash(
|
|
|
|
|
dims, n, alpha, beta, k, x->format(), ctx.op().Input("Out"));
|
|
|
|
|
|
|
|
|
|
auto diff_src_md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
|
|
|
|
|
platform::LRNMKLDNNHandler handler(false, dev_ctx, mkldnn_engine, key);
|
|
|
|
|
|
|
|
|
|
auto diff_dst_md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
|
|
|
|
|
auto src_md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, platform::MKLDNNGetDataType<T>(), x->format());
|
|
|
|
|
|
|
|
|
|
auto diff_dst_memory =
|
|
|
|
|
mkldnn::memory{{diff_dst_md, mkldnn_engine},
|
|
|
|
|
static_cast<void*>(const_cast<float*>(out_grad_data))};
|
|
|
|
|
// diff_dst and diff_src layouts are assumed to be the same
|
|
|
|
|
auto diff_md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
|
|
|
|
|
|
|
|
|
|
auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine},
|
|
|
|
|
static_cast<void*>(x_grad_data)};
|
|
|
|
|
auto workspace = handler.AcquireWorkspaceMemory();
|
|
|
|
|
|
|
|
|
|
auto backward_desc = mkldnn::lrn_backward::desc{
|
|
|
|
|
mkldnn::lrn_across_channels, src_md, diff_src_md, n, alpha, beta, k};
|
|
|
|
|
auto diff_dst_memory = handler.AcquireDiffDstMemory(
|
|
|
|
|
diff_md, platform::to_void_cast<T>(out_grad_data));
|
|
|
|
|
|
|
|
|
|
auto forward_pd = dev_ctx.GetBlob(key_pd);
|
|
|
|
|
auto diff_src_memory = handler.AcquireDiffSrcMemory(
|
|
|
|
|
diff_md, platform::to_void_cast<T>(x_grad_data));
|
|
|
|
|
|
|
|
|
|
auto backward_pd = mkldnn::lrn_backward::primitive_desc{
|
|
|
|
|
backward_desc, mkldnn_engine,
|
|
|
|
|
*static_cast<mkldnn::lrn_forward::primitive_desc*>(forward_pd.get())};
|
|
|
|
|
auto src_memory = handler.AcquireSrcMemory(
|
|
|
|
|
src_md, platform::to_void_cast<T>(x->data<T>()));
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<void> workspace_memory =
|
|
|
|
|
dev_ctx.GetBlob(key_workspace_memory);
|
|
|
|
|
// TODO(jczaja): Hide this call inside Handler
|
|
|
|
|
handler.AcquireLRNBackwardPrimitiveDescriptor(src_md, diff_md, n, alpha,
|
|
|
|
|
beta, k);
|
|
|
|
|
|
|
|
|
|
auto src_memory = dev_ctx.GetBlob(key_src_memory);
|
|
|
|
|
auto backward_op = mkldnn::lrn_backward{
|
|
|
|
|
backward_pd, *static_cast<mkldnn::memory*>(src_memory.get()),
|
|
|
|
|
diff_dst_memory, *static_cast<mkldnn::memory*>(workspace_memory.get()),
|
|
|
|
|
diff_src_memory};
|
|
|
|
|
auto lrn_bwd = handler.AcquireLRNBackward(src_memory, diff_dst_memory,
|
|
|
|
|
workspace, diff_src_memory);
|
|
|
|
|
|
|
|
|
|
std::vector<mkldnn::primitive> pipeline = {backward_op};
|
|
|
|
|
std::vector<mkldnn::primitive> pipeline = {*lrn_bwd};
|
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
|
|
|
|
|
auto output_format =
|
|
|
|
|
(mkldnn::memory::format)diff_src_memory->get_primitive_desc()
|
|
|
|
|
.desc()
|
|
|
|
|
.data.format;
|
|
|
|
|
|
|
|
|
|
x_grad->set_layout(framework::DataLayout::kMKLDNN);
|
|
|
|
|
x_grad->set_format(output_format);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|