|
|
|
@ -28,33 +28,35 @@ class PReluFunctor {
|
|
|
|
|
public:
|
|
|
|
|
explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE T operator()(const T& X) const {
|
|
|
|
|
if (X > 0)
|
|
|
|
|
return X;
|
|
|
|
|
HOSTDEVICE T operator()(const T& x) const {
|
|
|
|
|
if (x > 0)
|
|
|
|
|
return x;
|
|
|
|
|
else
|
|
|
|
|
return X * alpha_;
|
|
|
|
|
return x * alpha_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
T alpha_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class PReluKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* X = context.Input<Tensor>("X");
|
|
|
|
|
auto* Out = context.Output<Tensor>("Out");
|
|
|
|
|
auto* x = context.Input<Tensor>("X");
|
|
|
|
|
auto* alpha = context.Input<Tensor>("Alpha");
|
|
|
|
|
auto* out = context.Output<Tensor>("Out");
|
|
|
|
|
|
|
|
|
|
const T* X_ptr = X->data<T>();
|
|
|
|
|
T* O_ptr = Out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const T* x_ptr = x->data<T>();
|
|
|
|
|
T* o_ptr = out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
|
|
|
|
|
auto alpha_val = alpha->data<T>()[0];
|
|
|
|
|
// auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
|
|
|
|
|
|
|
|
|
|
int numel = X->numel();
|
|
|
|
|
int numel = x->numel();
|
|
|
|
|
|
|
|
|
|
auto place = context.GetPlace();
|
|
|
|
|
Transform(place, X_ptr, X_ptr + numel, O_ptr, PReluFunctor<T>(alpha));
|
|
|
|
|
Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_val));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -63,36 +65,36 @@ class PReluGradFunctor {
|
|
|
|
|
public:
|
|
|
|
|
explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
|
|
|
|
|
if (Out > 0)
|
|
|
|
|
return dOut;
|
|
|
|
|
HOSTDEVICE T operator()(const T& out, const T& dout) const {
|
|
|
|
|
if (out > 0)
|
|
|
|
|
return dout;
|
|
|
|
|
else
|
|
|
|
|
return dOut * alpha_;
|
|
|
|
|
return dout * alpha_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
T alpha_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class PReluGradKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* dO = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
auto* Out = context.Input<Tensor>("Out");
|
|
|
|
|
auto* out = context.Input<Tensor>("Out");
|
|
|
|
|
auto* alpha = context.Input<Tensor>("Alpha");
|
|
|
|
|
auto alpha_val = alpha->data<T>()[0];
|
|
|
|
|
|
|
|
|
|
auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
|
|
|
|
|
|
|
|
|
|
T* dX_ptr = dX->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const T* dO_ptr = dO->data<T>();
|
|
|
|
|
const T* O_ptr = Out->data<T>();
|
|
|
|
|
int numel = dX->numel();
|
|
|
|
|
T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const T* dout_ptr = dout->data<T>();
|
|
|
|
|
const T* out_ptr = out->data<T>();
|
|
|
|
|
int numel = dx->numel();
|
|
|
|
|
|
|
|
|
|
auto place = context.GetPlace();
|
|
|
|
|
Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr,
|
|
|
|
|
PReluGradFunctor<T>(alpha));
|
|
|
|
|
Transform(place, out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
|
|
|
|
|
PReluGradFunctor<T>(alpha_val));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|