fix no_grad signature (#23600)

* fix no_grad signature
test=develop

* check func name instead of doc
test=develop
revert-23830-2.0-beta
songyouwei 6 years ago committed by GitHub
parent f792d5f71b
commit a1a95f8108
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,6 +13,7 @@
# limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import contextlib
import functools
import sys
import numpy as np
from paddle.fluid import core
@ -195,6 +196,7 @@ def no_grad(func=None):
return _switch_tracer_mode_guard_(is_train=False)
else:
@functools.wraps(func)
def __impl__(*args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False):
return func(*args, **kwargs)

@ -49,6 +49,7 @@ class TestTracerMode(unittest.TestCase):
self.tracer._train_mode = self.init_mode
self.assertEqual(self.no_grad_func(1), 1)
self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
self.assertEqual(self.tracer._train_mode, self.init_mode)

Loading…
Cancel
Save