|
|
@ -73,8 +73,13 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
|
|
|
|
const auto *x = ctx.Input<Tensor>("X");
|
|
|
|
const auto *x = ctx.Input<Tensor>("X");
|
|
|
|
auto *y = ctx.Output<Tensor>("Out");
|
|
|
|
auto *y = ctx.Output<Tensor>("Out");
|
|
|
|
|
|
|
|
|
|
|
|
const T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
|
|
|
|
T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
|
|
|
|
const T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
|
|
|
|
T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// paddle uses beta but mkldnn uses alpha for swish
|
|
|
|
|
|
|
|
if (algorithm == mkldnn::algorithm::eltwise_swish) {
|
|
|
|
|
|
|
|
std::swap(alpha, beta);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
|
|
|
|
x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
|
|
|
@ -112,8 +117,13 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
|
|
|
|
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
|
|
|
|
|
|
|
const T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
|
|
|
|
T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
|
|
|
|
const T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
|
|
|
|
T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// paddle uses beta but mkldnn uses alpha for swish
|
|
|
|
|
|
|
|
if (algorithm == mkldnn::algorithm::eltwise_swish) {
|
|
|
|
|
|
|
|
std::swap(alpha, beta);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
|
|
|
|
auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
|
|
|
|
|
|
|
|
|
|
|
@ -162,6 +172,10 @@ template <typename T>
|
|
|
|
using ReluMKLDNNFunctor =
|
|
|
|
using ReluMKLDNNFunctor =
|
|
|
|
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
|
|
|
|
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
using SwishMKLDNNFunctor =
|
|
|
|
|
|
|
|
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_swish>;
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
using TanhMKLDNNFunctor =
|
|
|
|
using TanhMKLDNNFunctor =
|
|
|
|
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
|
|
|
|
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
|
|
|
@ -178,6 +192,10 @@ template <typename T>
|
|
|
|
using ReluMKLDNNGradFunctor =
|
|
|
|
using ReluMKLDNNGradFunctor =
|
|
|
|
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
|
|
|
|
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
using SwishMKLDNNGradFunctor =
|
|
|
|
|
|
|
|
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_swish>;
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
using TanhMKLDNNGradFunctor =
|
|
|
|
using TanhMKLDNNGradFunctor =
|
|
|
|
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
|
|
|
|
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
|
|
|
@ -204,6 +222,7 @@ namespace ops = paddle::operators;
|
|
|
|
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
|
|
|
|
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
|
|
|
|
__macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
|
|
|
|
__macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
|
|
|
|
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
|
|
|
|
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
|
|
|
|
|
|
|
|
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
|
|
|
|
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
|
|
|
|
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
|
|
|
|
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \
|
|
|
|
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \
|
|
|
|
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);
|
|
|
|
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);
|
|
|
|