|
|
|
@ -54,8 +54,8 @@ class PReluKernel : public framework::OpKernel {
|
|
|
|
|
|
|
|
|
|
int numel = x->numel();
|
|
|
|
|
|
|
|
|
|
auto place = context.GetPlace();
|
|
|
|
|
Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_ptr));
|
|
|
|
|
Transform(context.device_context(), x_ptr, x_ptr + numel, o_ptr,
|
|
|
|
|
PReluFunctor<T>(alpha_ptr));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -91,9 +91,8 @@ class PReluGradKernel : public framework::OpKernel {
|
|
|
|
|
const T* out_ptr = out->data<T>();
|
|
|
|
|
int numel = dx->numel();
|
|
|
|
|
|
|
|
|
|
auto place = context.GetPlace();
|
|
|
|
|
Transform(place, out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
|
|
|
|
|
PReluGradFunctor<T>(alpha_ptr));
|
|
|
|
|
Transform(context.device_context(), out_ptr, out_ptr + numel, dout_ptr,
|
|
|
|
|
dx_ptr, PReluGradFunctor<T>(alpha_ptr));
|
|
|
|
|
|
|
|
|
|
// TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
|
|
|
|
|
}
|
|
|
|
|