add assert raises in the test_retain_graph UT. (#25983)

revert-24895-update_cub
Zhen Wang 5 years ago committed by GitHub
parent 7165f48409
commit e656ca4783
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,7 +20,7 @@ import unittest
paddle.disable_static()
SEED = 2020
np.random.seed(SEED)
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
class Generator(fluid.dygraph.Layer):
@ -90,7 +90,7 @@ class TestRetainGraph(unittest.TestCase):
else:
return 0.0, None
def test_retain(self):
def run_retain(self, need_retain):
g = Generator()
d = Discriminator()
@ -117,7 +117,7 @@ class TestRetainGraph(unittest.TestCase):
d, realA, fakeB, lambda_gp=10.0)
loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty
loss_d.backward(retain_graph=True)
loss_d.backward(retain_graph=need_retain)
optim_d.minimize(loss_d)
optim_g.clear_gradients()
@ -130,6 +130,11 @@ class TestRetainGraph(unittest.TestCase):
loss_g.backward()
optim_g.minimize(loss_g)
def test_retain(self):
self.run_retain(need_retain=True)
self.assertRaises(
fluid.core.EnforceNotMet, self.run_retain, need_retain=False)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save