fix sqrt unittest. test=develop (#17440)

resnext-opt
Kaipeng Deng 6 years ago committed by GitHub
parent 977e9fcb27
commit 14f223624f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -96,8 +96,8 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
class TestSqrtDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [7, 9]
eps = 0.005
shape = [3, 7]
eps = 0.0001
dtype = np.float64
x = layers.data('x', shape, False, dtype)
@ -107,9 +107,9 @@ class TestSqrtDoubleGradCheck(unittest.TestCase):
x_arr = np.random.uniform(0.1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps, rtol=1e-2, atol=1e-2)
[x], y, x_init=x_arr, place=place, eps=eps)
def no_test_grad(self):
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places = [fluid.CUDAPlace(0)]

Loading…
Cancel
Save