|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import paddle.fluid.framework as framework
|
|
|
|
|
import unittest
|
|
|
|
|
from test_imperative_base import new_program_scope
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTracerMode(unittest.TestCase):
|
|
|
|
@ -29,6 +30,18 @@ class TestTracerMode(unittest.TestCase):
|
|
|
|
|
self.assertEqual(self.tracer._train_mode, False)
|
|
|
|
|
return a
|
|
|
|
|
|
|
|
|
|
@fluid.dygraph.base._not_support
|
|
|
|
|
def not_support_func(self):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def check_not_support_rlt(self, ans):
|
|
|
|
|
try:
|
|
|
|
|
rlt = self.not_support_func()
|
|
|
|
|
except AssertionError:
|
|
|
|
|
rlt = False
|
|
|
|
|
finally:
|
|
|
|
|
self.assertEqual(rlt, ans)
|
|
|
|
|
|
|
|
|
|
def test_main(self):
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
self.tracer = framework._dygraph_tracer()
|
|
|
|
@ -38,6 +51,12 @@ class TestTracerMode(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
self.assertEqual(self.tracer._train_mode, self.init_mode)
|
|
|
|
|
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
self.check_not_support_rlt(False)
|
|
|
|
|
|
|
|
|
|
with new_program_scope():
|
|
|
|
|
self.check_not_support_rlt(True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTracerMode2(TestTracerMode):
|
|
|
|
|
def setUp(self):
|
|
|
|
|