fix no_grad argspec (#23790)

test=develop
songyouwei 6 years ago committed by GitHub
parent 9549b78691
commit 8f63a3ecff

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
+import decorator
 import contextlib
 import functools
 import sys
@@ -196,12 +197,12 @@ def no_grad(func=None):
         return _switch_tracer_mode_guard_(is_train=False)
     else:
-        @functools.wraps(func)
-        def __impl__(*args, **kwargs):
+        @decorator.decorator
+        def __impl__(func, *args, **kwargs):
             with _switch_tracer_mode_guard_(is_train=False):
                 return func(*args, **kwargs)
-        return __impl__
+        return __impl__(func)

 @signature_safe_contextmanager
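
Why the switch matters: a wrapper built with functools.wraps keeps the original __name__ and __doc__, but argspec introspection still reports the wrapper's own (*args, **kwargs) parameters, while the third-party decorator package (the dependency this commit imports) regenerates the wrapper with the wrapped function's exact signature. A minimal standalone sketch of the difference, outside Paddle; it uses inspect.getfullargspec since getargspec is removed in recent Python versions, and the printed results are approximate:

import functools
import inspect

import decorator


def with_wraps(func):
    # functools.wraps copies __name__/__doc__, but introspection still sees
    # the wrapper's own (*args, **kwargs) parameter list.
    @functools.wraps(func)
    def __impl__(*args, **kwargs):
        return func(*args, **kwargs)

    return __impl__


@decorator.decorator
def with_decorator(func, *args, **kwargs):
    # decorator.decorator rebuilds the wrapper with func's exact signature.
    return func(*args, **kwargs)


def need_no_grad_func(a, b=1):
    return a + b


print(inspect.getfullargspec(with_wraps(need_no_grad_func)))
# roughly: FullArgSpec(args=[], varargs='args', varkw='kwargs', defaults=None, ...)
print(inspect.getfullargspec(with_decorator(need_no_grad_func)))
# roughly: FullArgSpec(args=['a', 'b'], varargs=None, varkw=None, defaults=(1,), ...)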

@@ -15,6 +15,7 @@
 import paddle.fluid as fluid
 import paddle.fluid.framework as framework
 import unittest
+import inspect
 from test_imperative_base import new_program_scope
@@ -51,6 +52,14 @@ class TestTracerMode(unittest.TestCase):
         self.assertEqual(self.no_grad_func(1), 1)
         self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
+
+        def need_no_grad_func(a, b=1):
+            return a + b
+
+        decorated_func = fluid.dygraph.no_grad(need_no_grad_func)
+        self.assertTrue(
+            str(inspect.getargspec(decorated_func)) ==
+            str(inspect.getargspec(need_no_grad_func)))
+
         self.assertEqual(self.tracer._train_mode, self.init_mode)
         with fluid.dygraph.guard():
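
For reference, a hedged usage sketch of the two calling modes that no_grad(func=None) supports per the hunk above: as a decorator (the path this commit fixes) and, when called with no function, as a context manager. The Linear layer, shapes, and data are illustrative only and assume the fluid dygraph API of this era, not anything taken from the commit:

import numpy as np
import paddle.fluid as fluid


@fluid.dygraph.no_grad
def eval_step(layer, data):
    # Runs with the tracer's _train_mode switched off, so no gradients are recorded.
    return layer(data)


with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(4, 2)
    x = fluid.dygraph.to_variable(np.ones([3, 4], dtype='float32'))
    y = eval_step(linear, x)

    # Equivalent context-manager form (the func is None branch).
    with fluid.dygraph.no_grad():
        y2 = linear(x)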
