|
|
|
@ -13,8 +13,8 @@
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "mkldnn.hpp"
|
|
|
|
|
#include "mkldnn_activation_op.h"
|
|
|
|
|
#include "paddle/fluid/operators/activation_op.h"
|
|
|
|
|
#include "paddle/fluid/operators/mkldnn_activation_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -46,14 +46,18 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
|
|
|
|
|
|
|
|
|
|
// create memory description
|
|
|
|
|
auto data_md = src_tz.size() == 2
|
|
|
|
|
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nc)
|
|
|
|
|
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nc)
|
|
|
|
|
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
// create memory primitives
|
|
|
|
|
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src_data);
|
|
|
|
|
auto dst_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)dst_data);
|
|
|
|
|
auto src_memory =
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine},
|
|
|
|
|
static_cast<void *>(const_cast<float *>(src_data)));
|
|
|
|
|
auto dst_memory =
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine},
|
|
|
|
|
static_cast<void *>(const_cast<float *>(dst_data)));
|
|
|
|
|
|
|
|
|
|
auto forward_desc = mkldnn::eltwise_forward::desc(
|
|
|
|
|
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
|
|
|
|
@ -94,17 +98,20 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
|
|
|
|
|
|
|
|
|
|
// create memory description
|
|
|
|
|
auto data_md = src_tz.size() == 2
|
|
|
|
|
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nc)
|
|
|
|
|
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nc)
|
|
|
|
|
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
// create memory primitives
|
|
|
|
|
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src);
|
|
|
|
|
auto src_memory = mkldnn::memory(
|
|
|
|
|
{data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
|
|
|
|
|
auto diff_src_memory =
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_src);
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine},
|
|
|
|
|
static_cast<void *>(const_cast<float *>(diff_src)));
|
|
|
|
|
auto diff_dst_memory =
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_dst);
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine},
|
|
|
|
|
static_cast<void *>(const_cast<float *>(diff_dst)));
|
|
|
|
|
|
|
|
|
|
auto backward_desc =
|
|
|
|
|
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
|
|
|
|
|