@@ -76,6 +76,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
  // paddle uses beta but mkldnn uses alpha for swish
  if (algorithm == mkldnn::algorithm::eltwise_swish) {
    std::swap(alpha, beta);
  } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
    alpha = ctx.Attr<T>("threshold");
  }

  PADDLE_ENFORCE(
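For reference, the branch added above maps Paddle attributes onto oneDNN's single alpha/beta parameter pair: eltwise_bounded_relu clamps its input to [0, alpha], so relu6's `threshold` attribute becomes alpha, while eltwise_swish computes x * sigmoid(alpha * x), so Paddle's swish `beta` is swapped into alpha. A minimal scalar sketch of those semantics (helper names here are illustrative only, not part of the patch):

    #include <algorithm>
    #include <cmath>

    // Illustrative reference semantics only, not the patched kernel.
    // oneDNN eltwise_bounded_relu(x; alpha) clamps to [0, alpha], so Paddle's
    // relu6 `threshold` is passed as alpha.
    inline float bounded_relu_ref(float x, float alpha) {
      return std::min(std::max(x, 0.0f), alpha);
    }

    // oneDNN eltwise_swish(x; alpha) is x * sigmoid(alpha * x), so Paddle's
    // swish `beta` takes the place of alpha (hence the std::swap above).
    inline float swish_ref(float x, float alpha) {
      return x / (1.0f + std::exp(-alpha * x));
    }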
@@ -119,6 +121,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
  // paddle uses beta but mkldnn uses alpha for swish
  if (algorithm == mkldnn::algorithm::eltwise_swish) {
    std::swap(alpha, beta);
  } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
    alpha = ctx.Attr<T>("threshold");
  }

  auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
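The same attribute mapping is repeated in the backward pass, since the gradient kernels are parameterised identically. As a hedged scalar reference for the two gradients involved (again, helper names are illustrative, not part of the patch):

    #include <cmath>

    // Illustrative reference gradients only.
    // d/dx min(max(x, 0), alpha) is 1 on (0, alpha) and 0 elsewhere.
    inline float bounded_relu_grad_ref(float dy, float x, float alpha) {
      return (x > 0.0f && x < alpha) ? dy : 0.0f;
    }

    // d/dx [x * sigmoid(alpha * x)] = sig + alpha * x * sig * (1 - sig),
    // with sig = sigmoid(alpha * x).
    inline float swish_grad_ref(float dy, float x, float alpha) {
      const float sig = 1.0f / (1.0f + std::exp(-alpha * x));
      return dy * (sig + alpha * x * sig * (1.0f - sig));
    }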
@@ -192,6 +196,10 @@ template <typename T>
using ReluMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using Relu6MKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

template <typename T>
using SwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_swish>;
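Each alias above only binds a compile-time oneDNN algorithm tag to the shared activation functor. Assuming MKLDNNActivationFunc in this file simply forwards that tag into eltwise_forward, the shape of such a functor is roughly as follows (a sketch under that assumption, not the patched code):

    // Sketch only: the functor dispatches the bound algorithm tag into the
    // shared eltwise_forward helper from the first hunk.
    template <typename T, mkldnn::algorithm algorithm>
    struct ActivationFuncSketch {
      void operator()(const framework::ExecutionContext &ctx) const {
        eltwise_forward<T>(ctx, algorithm);
      }
    };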
@@ -216,6 +224,10 @@ template <typename T>
using ReluMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using Relu6MKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

template <typename T>
using SwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_swish>;
@@ -249,6 +261,7 @@ namespace ops = paddle::operators;
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                   \
  __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);        \
  __macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);     \
  __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);  \
  __macro(gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor);        \
  __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);     \
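FOR_EACH_MKLDNN_KERNEL_FUNCTOR is an X-macro: a registration macro is passed in and expanded once per (op type, forward functor, grad functor) triple, so adding the relu6 entry is enough to register both the forward and the backward kernel. A standalone illustration of how such a list is consumed (demo names only, not the file's real registration macro):

    #include <cstdio>

    // Demo list in the same X-macro style as FOR_EACH_MKLDNN_KERNEL_FUNCTOR.
    #define FOR_EACH_DEMO_FUNCTOR(__macro)      \
      __macro(relu, ReluDemo, ReluGradDemo);    \
      __macro(relu6, Relu6Demo, Relu6GradDemo);

    // Stand-in for a kernel-registration macro: just prints each expansion.
    #define PRINT_REGISTRATION(act_type, functor, grad_functor) \
      std::printf("register %s -> %s / %s\n", #act_type, #functor, #grad_functor)

    int main() {
      FOR_EACH_DEMO_FUNCTOR(PRINT_REGISTRATION);
      return 0;
    }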