|
|
|
@ -24,9 +24,9 @@ using Tensor = framework::Tensor;
|
|
|
|
|
using platform::Transform;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class Prelu_functor {
|
|
|
|
|
class PReluFunctor {
|
|
|
|
|
public:
|
|
|
|
|
explicit Prelu_functor(const T& alpha) : alpha_(alpha) {}
|
|
|
|
|
explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE T operator()(const T& X) const {
|
|
|
|
|
if (X > 0)
|
|
|
|
@ -54,14 +54,14 @@ class PReluKernel : public framework::OpKernel {
|
|
|
|
|
int numel = X->numel();
|
|
|
|
|
|
|
|
|
|
auto place = context.GetPlace();
|
|
|
|
|
Transform(place, X_ptr, X_ptr + numel, O_ptr, Prelu_functor<T>(alpha));
|
|
|
|
|
Transform(place, X_ptr, X_ptr + numel, O_ptr, PReluFunctor<T>(alpha));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class Prelu_Grad_functor {
|
|
|
|
|
class PReluGradFunctor {
|
|
|
|
|
public:
|
|
|
|
|
explicit Prelu_Grad_functor(const T& alpha) : alpha_(alpha) {}
|
|
|
|
|
explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
|
|
|
|
|
if (Out > 0)
|
|
|
|
@ -92,7 +92,7 @@ class PReluGradKernel : public framework::OpKernel {
|
|
|
|
|
|
|
|
|
|
auto place = context.GetPlace();
|
|
|
|
|
Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr,
|
|
|
|
|
Prelu_Grad_functor<T>(alpha));
|
|
|
|
|
PReluGradFunctor<T>(alpha));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|