|
|
|
@ -20,7 +20,7 @@ import unittest
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
SEED = 2020
|
|
|
|
|
np.random.seed(SEED)
|
|
|
|
|
fluid.default_main_program().random_seed = SEED
|
|
|
|
|
paddle.manual_seed(SEED)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Generator(fluid.dygraph.Layer):
|
|
|
|
@ -90,7 +90,7 @@ class TestRetainGraph(unittest.TestCase):
|
|
|
|
|
else:
|
|
|
|
|
return 0.0, None
|
|
|
|
|
|
|
|
|
|
def test_retain(self):
|
|
|
|
|
def run_retain(self, need_retain):
|
|
|
|
|
g = Generator()
|
|
|
|
|
d = Discriminator()
|
|
|
|
|
|
|
|
|
@ -117,7 +117,7 @@ class TestRetainGraph(unittest.TestCase):
|
|
|
|
|
d, realA, fakeB, lambda_gp=10.0)
|
|
|
|
|
loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty
|
|
|
|
|
|
|
|
|
|
loss_d.backward(retain_graph=True)
|
|
|
|
|
loss_d.backward(retain_graph=need_retain)
|
|
|
|
|
optim_d.minimize(loss_d)
|
|
|
|
|
|
|
|
|
|
optim_g.clear_gradients()
|
|
|
|
@ -130,6 +130,11 @@ class TestRetainGraph(unittest.TestCase):
|
|
|
|
|
loss_g.backward()
|
|
|
|
|
optim_g.minimize(loss_g)
|
|
|
|
|
|
|
|
|
|
def test_retain(self):
|
|
|
|
|
self.run_retain(need_retain=True)
|
|
|
|
|
self.assertRaises(
|
|
|
|
|
fluid.core.EnforceNotMet, self.run_retain, need_retain=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|