|
|
|
@ -38,10 +38,9 @@ class PReluKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto dim = x->dims();
|
|
|
|
|
int index = 0;
|
|
|
|
|
int i = 0;
|
|
|
|
|
int temp = 0;
|
|
|
|
|
if (mode == "channel") {
|
|
|
|
|
int temp = numel / (dim[0] * dim[1]);
|
|
|
|
|
for (i = 0; i < numel; i++) {
|
|
|
|
|
temp = numel / (dim[0] * dim[1]);
|
|
|
|
|
index = (i / temp) % dim[1];
|
|
|
|
|
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
|
|
|
|
|
}
|
|
|
|
|