fix train eval set error in static mode (#29540)

revert-31562-mean
Chen Weihang 4 years ago committed by GitHub
parent b5d4a1f33d
commit c1a26e2a05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -133,8 +133,11 @@ class Layer(core.Layer):
out = mylayer(x)
"""
# global setting
framework._dygraph_tracer().train_mode()
# global setting in dygraph
# NOTE(chenweihang): nn.Layer also can be used in static mode,
# but _dygraph_tracer() can not be called in static mode
if in_dygraph_mode():
framework._dygraph_tracer().train_mode()
# Layer-level setting
self.training = True
for layer in self.sublayers():
@ -171,8 +174,11 @@ class Layer(core.Layer):
print(out)
"""
# global setting
framework._dygraph_tracer().eval_mode()
# global setting in dygraph
# NOTE(chenweihang): nn.Layer also can be used in static mode,
# but _dygraph_tracer() can not be called in static mode
if in_dygraph_mode():
framework._dygraph_tracer().eval_mode()
# Layer-level setting
self.training = False
for layer in self.sublayers():

@ -3701,6 +3701,23 @@ class TestLayerParameterTrainableSet(unittest.TestCase):
self.assertFalse(net.weight.trainable)
class TestLayerTrainingAttribute(unittest.TestCase):
def test_set_train_eval_in_dynamic_mode(self):
with fluid.dygraph.guard():
net = paddle.nn.Dropout()
net.train()
self.assertTrue(net.training)
net.eval()
self.assertFalse(net.training)
def test_set_train_eval_in_static_mode(self):
net = paddle.nn.Dropout()
net.train()
self.assertTrue(net.training)
net.eval()
self.assertFalse(net.training)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()

Loading…
Cancel
Save