|
|
|
@ -13,8 +13,8 @@
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "mkldnn.hpp"
|
|
|
|
|
#include "mkldnn_activation_op.h"
|
|
|
|
|
#include "paddle/fluid/operators/activation_op.h"
|
|
|
|
|
#include "paddle/fluid/operators/mkldnn_activation_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -40,18 +40,24 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
|
|
|
|
|
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
// get memory dim
|
|
|
|
|
PADDLE_ENFORCE(src->dims().size() == 4,
|
|
|
|
|
"Input dim must be with 4, i.e. NCHW");
|
|
|
|
|
PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
|
|
|
|
|
"Input dim must be with 2 or 4");
|
|
|
|
|
std::vector<int> src_tz = framework::vectorize2int(src->dims());
|
|
|
|
|
|
|
|
|
|
// create memory description
|
|
|
|
|
// TODO(kbinias-intel): support more formats
|
|
|
|
|
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
auto data_md = src_tz.size() == 2
|
|
|
|
|
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nc)
|
|
|
|
|
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
// create memory primitives
|
|
|
|
|
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src_data);
|
|
|
|
|
auto dst_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)dst_data);
|
|
|
|
|
auto src_memory =
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine},
|
|
|
|
|
static_cast<void *>(const_cast<float *>(src_data)));
|
|
|
|
|
auto dst_memory =
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine},
|
|
|
|
|
static_cast<void *>(const_cast<float *>(dst_data)));
|
|
|
|
|
|
|
|
|
|
auto forward_desc = mkldnn::eltwise_forward::desc(
|
|
|
|
|
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
|
|
|
|
@ -91,15 +97,21 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
|
|
|
|
|
std::vector<int> src_tz = framework::vectorize2int(x->dims());
|
|
|
|
|
|
|
|
|
|
// create memory description
|
|
|
|
|
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
auto data_md = src_tz.size() == 2
|
|
|
|
|
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nc)
|
|
|
|
|
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
// create memory primitives
|
|
|
|
|
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src);
|
|
|
|
|
auto src_memory = mkldnn::memory(
|
|
|
|
|
{data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
|
|
|
|
|
auto diff_src_memory =
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_src);
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine},
|
|
|
|
|
static_cast<void *>(const_cast<float *>(diff_src)));
|
|
|
|
|
auto diff_dst_memory =
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_dst);
|
|
|
|
|
mkldnn::memory({data_md, mkldnn_engine},
|
|
|
|
|
static_cast<void *>(const_cast<float *>(diff_dst)));
|
|
|
|
|
|
|
|
|
|
auto backward_desc =
|
|
|
|
|
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
|
|
|
|
|