Add no_grad decorator to dygraph (#17790)

* add no_grad decorator to dygraph, test=develop

* add unittest,test=develop
Zeng Jinle 6 years ago committed by GitHub
parent 53920f5e8a
commit 3a6ead24ad

@@ -11,20 +11,49 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ..wrapped_decorator import signature_safe_contextmanager
+from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
+import contextlib
 import numpy as np
 from paddle.fluid import core
 from paddle.fluid import framework
 from .tracer import Tracer
 
-__all__ = ['enabled', 'guard', 'to_variable']
+__all__ = [
+    'enabled',
+    'no_grad',
+    'guard',
+    'to_variable',
+]
 
 
 def enabled():
     return framework.in_dygraph_mode()
 
 
+@contextlib.contextmanager
+def _switch_tracer_mode_guard_(is_train=True):
+    tracer = framework._dygraph_tracer()
+    if tracer:
+        mode = tracer._train_mode
+        tracer._train_mode = is_train
+        yield
+        tracer._train_mode = mode
+    else:
+        yield
+
+
+def _no_grad_(func):
+    def __impl__(*args, **kwargs):
+        with _switch_tracer_mode_guard_(is_train=False):
+            return func(*args, **kwargs)
+
+    return __impl__
+
+
+no_grad = wrap_decorator(_no_grad_)
+
+
 @signature_safe_contextmanager
 def guard(place=None):
     train = framework.Program()
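With this hunk applied, no_grad is importable as fluid.dygraph.no_grad and can wrap any callable. A minimal usage sketch (the relu op is an assumption; any dygraph op would do):

import numpy as np
import paddle.fluid as fluid


@fluid.dygraph.no_grad
def inference(x):
    # while this runs, the tracer's _train_mode is False, so the op
    # below executes without recording backward information
    return fluid.layers.relu(x)


with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    y = inference(x)

Because _switch_tracer_mode_guard_ saves and restores the previous _train_mode, a call like this can sit inside a training loop without leaving the tracer stuck in eval mode. wrap_decorator comes from the same wrapped_decorator module as signature_safe_contextmanager, so it presumably preserves the wrapped function's original signature rather than replacing it with a bare *args, **kwargs wrapper.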

@@ -22,6 +22,7 @@ import functools
 from . import layers
 from . import framework
 from . import core
+from .dygraph import base as imperative_base
 
 __all__ = [
     'GradClipByValue',
@@ -37,6 +38,7 @@ class GradClipBase(object):
     def _clip(self, para_and_grad):
         raise NotImplementedError
 
+    @imperative_base.no_grad
     def __call__(self, para_and_grad):
         return self._clip(para_and_grad)
@@ -86,6 +88,7 @@ class GradClipByValue(GradClipBase):
     """
 
+    @imperative_base.no_grad
     def __init__(self, min_value, max_value=None):
         if min_value is None:
@@ -164,6 +167,7 @@ class GradClipByNorm(GradClipBase):
     """
 
+    @imperative_base.no_grad
     def __init__(self, clip_norm):
         self.clip_norm = clip_norm
@@ -243,6 +247,7 @@ class GradClipByGlobalNorm(GradClipBase):
     """
 
+    @imperative_base.no_grad
     def __init__(self, max_global_norm):
         self.max_global_norm = layers.fill_constant(
             shape=[1], dtype='float32', value=max_global_norm)
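no_grad is applied both to GradClipBase.__call__ and to each subclass's __init__, so the ops a clip object creates (such as the fill_constant above) run with the tracer in non-train mode and are never traced for backward. Any subclass inherits this for free; a minimal sketch, assuming the classes live in paddle.fluid.dygraph_grad_clip (the file path is not shown in this diff):

from paddle.fluid.dygraph_grad_clip import GradClipBase


class GradClipIdentity(GradClipBase):
    # hypothetical no-op clip: __call__ is inherited from GradClipBase
    # and already decorated with @imperative_base.no_grad, so nothing
    # done inside _clip is recorded for backward
    def _clip(self, para_and_grad):
        return para_and_grad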

@@ -55,6 +55,7 @@ class Optimizer(object):
     but need to use one of it's implementation.
     """
 
+    @imperative_base.no_grad
     def __init__(self, learning_rate, regularization=None, name=None):
         if framework.in_dygraph_mode():
             if not isinstance(learning_rate, float) and \
@@ -472,6 +473,7 @@ class Optimizer(object):
         optimize_ops = self.apply_gradients(params_grads)
         return optimize_ops
 
+    @imperative_base.no_grad
     def minimize(self,
                  loss,
                  startup_program=None,
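Decorating minimize means the parameter-update ops created during the optimizer step are themselves executed without gradient tracking. A rough sketch of a dygraph training step under this change, assuming the Paddle 1.x FC layer and SGDOptimizer APIs:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    fc = fluid.dygraph.FC('fc', size=4)  # assumed Paddle 1.x FC API
    opt = fluid.optimizer.SGDOptimizer(learning_rate=0.1)
    x = fluid.dygraph.to_variable(
        np.random.rand(2, 8).astype('float32'))
    loss = fluid.layers.reduce_mean(fc(x))
    loss.backward()
    # minimize is now wrapped in @imperative_base.no_grad, so the update
    # ops it creates run with _train_mode=False and are not traced
    opt.minimize(loss)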

@@ -0,0 +1,48 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+import unittest
+
+
+class TestTracerMode(unittest.TestCase):
+    def setUp(self):
+        self.init_mode = True
+
+    def get_tracer_mode(self):
+        assert fluid.dygraph.enabled(), "Dygraph mode must be enabled"
+
+    @fluid.dygraph.no_grad
+    def no_grad_func(self, a):
+        self.assertEqual(self.tracer._train_mode, False)
+        return a
+
+    def test_main(self):
+        with fluid.dygraph.guard():
+            self.tracer = framework._dygraph_tracer()
+            self.tracer._train_mode = self.init_mode
+
+            self.assertEqual(self.no_grad_func(1), 1)
+
+            self.assertEqual(self.tracer._train_mode, self.init_mode)
+
+
+class TestTracerMode2(TestTracerMode):
+    def setUp(self):
+        self.init_mode = False
+
+
+if __name__ == '__main__':
+    unittest.main()
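Since the file ends with unittest.main(), it can be run directly, e.g. (the file name here is hypothetical; the diff does not show the path):

python test_dygraph_no_grad.py

The subclass TestTracerMode2 reruns the same check with _train_mode initially False, verifying that no_grad restores whatever mode was active before the call, not just True.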