|
|
|
@ -7,6 +7,14 @@ class PReluTest(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "prelu"
|
|
|
|
|
x_np = np.random.normal(size=(10, 10)).astype("float32")
|
|
|
|
|
|
|
|
|
|
for pos, val in np.ndenumerate(x_np):
|
|
|
|
|
# Since zero point in prelu is not differentiable, avoid randomize
|
|
|
|
|
# zero.
|
|
|
|
|
while abs(val) < 1e-3:
|
|
|
|
|
x_np[pos] = np.random.normal()
|
|
|
|
|
val = x_np[pos]
|
|
|
|
|
|
|
|
|
|
x_np_sign = np.sign(x_np)
|
|
|
|
|
x_np = x_np_sign * np.maximum(x_np, .005)
|
|
|
|
|
alpha_np = np.array([.1])
|
|
|
|
|