|
|
|
@ -222,35 +222,35 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
using ReluMkldnnFunctor =
|
|
|
|
|
using ReluMKLDNNFunctor =
|
|
|
|
|
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
using TanhMkldnnFunctor =
|
|
|
|
|
using TanhMKLDNNFunctor =
|
|
|
|
|
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
using SqrtMkldnnFunctor =
|
|
|
|
|
using SqrtMKLDNNFunctor =
|
|
|
|
|
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
using AbsMkldnnFunctor =
|
|
|
|
|
using AbsMKLDNNFunctor =
|
|
|
|
|
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
using ReluMkldnnGradFunctor =
|
|
|
|
|
using ReluMKLDNNGradFunctor =
|
|
|
|
|
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
using TanhMkldnnGradFunctor =
|
|
|
|
|
using TanhMKLDNNGradFunctor =
|
|
|
|
|
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
using SqrtMkldnnGradFunctor =
|
|
|
|
|
using SqrtMKLDNNGradFunctor =
|
|
|
|
|
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
using AbsMkldnnGradFunctor =
|
|
|
|
|
using AbsMKLDNNGradFunctor =
|
|
|
|
|
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
@ -265,9 +265,9 @@ namespace ops = paddle::operators;
|
|
|
|
|
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
|
|
|
|
|
|
|
|
|
|
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
|
|
|
|
|
__macro(relu, ReluMkldnnFunctor, ReluMkldnnGradFunctor); \
|
|
|
|
|
__macro(tanh, TanhMkldnnFunctor, TanhMkldnnGradFunctor); \
|
|
|
|
|
__macro(sqrt, SqrtMkldnnFunctor, SqrtMkldnnGradFunctor); \
|
|
|
|
|
__macro(abs, AbsMkldnnFunctor, AbsMkldnnGradFunctor);
|
|
|
|
|
__macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
|
|
|
|
|
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
|
|
|
|
|
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \
|
|
|
|
|
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);
|
|
|
|
|
|
|
|
|
|
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
|
|
|
|
|