|
|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import paddle.fluid.framework as framework
|
|
|
|
|
import unittest
|
|
|
|
|
import inspect
|
|
|
|
|
|
|
|
|
|
from test_imperative_base import new_program_scope
|
|
|
|
|
|
|
|
|
|
@ -51,6 +52,14 @@ class TestTracerMode(unittest.TestCase):
|
|
|
|
|
self.assertEqual(self.no_grad_func(1), 1)
|
|
|
|
|
self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
|
|
|
|
|
|
|
|
|
|
def need_no_grad_func(a, b=1):
|
|
|
|
|
return a + b
|
|
|
|
|
|
|
|
|
|
decorated_func = fluid.dygraph.no_grad(need_no_grad_func)
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
str(inspect.getargspec(decorated_func)) ==
|
|
|
|
|
str(inspect.getargspec(need_no_grad_func)))
|
|
|
|
|
|
|
|
|
|
self.assertEqual(self.tracer._train_mode, self.init_mode)
|
|
|
|
|
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
|