Merge pull request #4372 from reyoung/feature/stable_prelu_grad_test

Stabilize prelu gradient check
update-doc-pybind
Zhuoyuan 8 years ago committed by GitHub
commit e5a3c1d2d5

@ -7,6 +7,14 @@ class PReluTest(OpTest):
def setUp(self):
self.op_type = "prelu"
x_np = np.random.normal(size=(10, 10)).astype("float32")
for pos, val in np.ndenumerate(x_np):
# Since zero point in prelu is not differentiable, avoid randomize
# zero.
while abs(val) < 1e-3:
x_np[pos] = np.random.normal()
val = x_np[pos]
x_np_sign = np.sign(x_np)
x_np = x_np_sign * np.maximum(x_np, .005)
alpha_np = np.array([.1])

Loading…
Cancel
Save