|
|
|
@ -32,7 +32,7 @@ class PReluKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* o_ptr = out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
const T* alpha_ptr = alpha->data<T>();
|
|
|
|
|
std::string mode = context.Attr<std::string>("mode");
|
|
|
|
|
auto& mode = context.Attr<std::string>("mode");
|
|
|
|
|
|
|
|
|
|
int numel = x->numel();
|
|
|
|
|
auto dim = x->dims();
|
|
|
|
@ -99,6 +99,8 @@ class PReluGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
index = 0;
|
|
|
|
|
if (dalpha) {
|
|
|
|
|
T* dalpha_ptr = dalpha->mutable_data<T>(context.GetPlace());
|
|
|
|
|
memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel());
|
|
|
|
|
|
|
|
|
|
if (mode == "channel") {
|
|
|
|
|
for (i = 0; i < numel; i++) {
|
|
|
|
|
temp = numel / (dim[0] * dim[1]);
|
|
|
|
|