|
|
|
@ -22,18 +22,6 @@ namespace operators {
|
|
|
|
|
using paddle::framework::Tensor;
|
|
|
|
|
using paddle::platform::MKLDNNDeviceContext;
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
mkldnn::algorithm LRNAlgorithm(const paddle::framework::ExecutionContext& ctx) {
|
|
|
|
|
mkldnn::algorithm algorithm = mkldnn::lrn_across_channels;
|
|
|
|
|
|
|
|
|
|
std::string algorithm_str = ctx.Attr<std::string>("algorithm");
|
|
|
|
|
if (algorithm_str == "WITHIN_CHANNEL") {
|
|
|
|
|
algorithm = mkldnn::lrn_within_channel;
|
|
|
|
|
}
|
|
|
|
|
return algorithm;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -64,8 +52,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
const float beta = ctx.Attr<float>("beta");
|
|
|
|
|
const float k = ctx.Attr<float>("k");
|
|
|
|
|
|
|
|
|
|
auto algorithm = LRNAlgorithm(ctx);
|
|
|
|
|
|
|
|
|
|
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
|
|
|
|
|
e_mid = e_mid.constant(k);
|
|
|
|
|
|
|
|
|
@ -77,8 +63,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto dst_md = paddle::platform::MKLDNNMemDesc(
|
|
|
|
|
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
auto forward_desc = mkldnn::lrn_forward::desc{
|
|
|
|
|
mkldnn::prop_kind::forward, algorithm, src_md, n, alpha, beta, k};
|
|
|
|
|
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
|
|
|
|
|
mkldnn::lrn_across_channels,
|
|
|
|
|
src_md,
|
|
|
|
|
n,
|
|
|
|
|
alpha,
|
|
|
|
|
beta,
|
|
|
|
|
k};
|
|
|
|
|
|
|
|
|
|
auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
|
|
|
|
|
forward_desc, mkldnn_engine);
|
|
|
|
@ -154,10 +145,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine},
|
|
|
|
|
static_cast<void*>(x_grad_data)};
|
|
|
|
|
|
|
|
|
|
auto algorithm = LRNAlgorithm(ctx);
|
|
|
|
|
|
|
|
|
|
auto backward_desc = mkldnn::lrn_backward::desc{
|
|
|
|
|
algorithm, src_md, diff_src_md, n, alpha, beta, k};
|
|
|
|
|
mkldnn::lrn_across_channels, src_md, diff_src_md, n, alpha, beta, k};
|
|
|
|
|
|
|
|
|
|
auto forward_pd = dev_ctx.GetBlob(key_pd);
|
|
|
|
|
|
|
|
|
|