|
|
|
@ -26,17 +26,17 @@ using platform::Transform;
|
|
|
|
|
template <typename T>
|
|
|
|
|
class PReluFunctor {
|
|
|
|
|
public:
|
|
|
|
|
explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
|
|
|
|
|
explicit PReluFunctor(const T* alpha) : alpha_(alpha) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE T operator()(const T& x) const {
|
|
|
|
|
if (x > 0)
|
|
|
|
|
return x;
|
|
|
|
|
else
|
|
|
|
|
return x * alpha_;
|
|
|
|
|
return x * (*alpha_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
T alpha_;
|
|
|
|
|
const T* alpha_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
@ -50,30 +50,29 @@ class PReluKernel : public framework::OpKernel {
|
|
|
|
|
const T* x_ptr = x->data<T>();
|
|
|
|
|
T* o_ptr = out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto alpha_val = alpha->data<T>()[0];
|
|
|
|
|
// auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
|
|
|
|
|
auto* alpha_ptr = alpha->data<T>();
|
|
|
|
|
|
|
|
|
|
int numel = x->numel();
|
|
|
|
|
|
|
|
|
|
auto place = context.GetPlace();
|
|
|
|
|
Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_val));
|
|
|
|
|
Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_ptr));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class PReluGradFunctor {
|
|
|
|
|
public:
|
|
|
|
|
explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}
|
|
|
|
|
explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE T operator()(const T& out, const T& dout) const {
|
|
|
|
|
if (out > 0)
|
|
|
|
|
return dout;
|
|
|
|
|
else
|
|
|
|
|
return dout * alpha_;
|
|
|
|
|
return dout * (*alpha_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
T alpha_;
|
|
|
|
|
const T* alpha_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
@ -85,7 +84,7 @@ class PReluGradKernel : public framework::OpKernel {
|
|
|
|
|
|
|
|
|
|
auto* out = context.Input<Tensor>("Out");
|
|
|
|
|
auto* alpha = context.Input<Tensor>("Alpha");
|
|
|
|
|
auto alpha_val = alpha->data<T>()[0];
|
|
|
|
|
auto* alpha_ptr = alpha->data<T>();
|
|
|
|
|
|
|
|
|
|
T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const T* dout_ptr = dout->data<T>();
|
|
|
|
@ -94,7 +93,9 @@ class PReluGradKernel : public framework::OpKernel {
|
|
|
|
|
|
|
|
|
|
auto place = context.GetPlace();
|
|
|
|
|
Transform(place, out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
|
|
|
|
|
PReluGradFunctor<T>(alpha_val));
|
|
|
|
|
PReluGradFunctor<T>(alpha_ptr));
|
|
|
|
|
|
|
|
|
|
// TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|