|
|
|
@ -16,7 +16,7 @@ from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import collections
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from .wrapped_decorator import contextmanager
|
|
|
|
|
from .wrapped_decorator import signature_safe_contextmanager
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import traceback
|
|
|
|
@ -111,7 +111,7 @@ class NameScope(object):
|
|
|
|
|
_name_scope = NameScope()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
@signature_safe_contextmanager
|
|
|
|
|
def name_scope(prefix=None):
|
|
|
|
|
"""
|
|
|
|
|
Generate hierarchical name prefix for the operators.
|
|
|
|
@ -1775,7 +1775,7 @@ class Program(object):
|
|
|
|
|
def set_op_role_var(self, var_name):
|
|
|
|
|
self._op_role_var = [var_name]
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
@signature_safe_contextmanager
|
|
|
|
|
def _optimized_guard(self, param_and_grads):
|
|
|
|
|
"""
|
|
|
|
|
A with guard to set :code:`Optimization` :code:`OpRole` and
|
|
|
|
@ -1805,7 +1805,7 @@ class Program(object):
|
|
|
|
|
self._op_role_var = tmp_var
|
|
|
|
|
self._current_role = tmp_role
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
@signature_safe_contextmanager
|
|
|
|
|
def _lr_schedule_guard(self, is_with_opt=False):
|
|
|
|
|
"""
|
|
|
|
|
A with guard to set :code:`LRSched` :code:`OpRole` and
|
|
|
|
@ -2459,7 +2459,7 @@ def switch_startup_program(program):
|
|
|
|
|
return prev_program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
@signature_safe_contextmanager
|
|
|
|
|
def program_guard(main_program, startup_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Change the global main program and startup program with `with` statement.
|
|
|
|
@ -2524,7 +2524,7 @@ def _get_var(name, program=None):
|
|
|
|
|
return program.global_block().var(name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
@signature_safe_contextmanager
|
|
|
|
|
def _imperative_guard(tracer):
|
|
|
|
|
global _imperative_tracer_
|
|
|
|
|
tmp_trace = _imperative_tracer_
|
|
|
|
@ -2535,7 +2535,7 @@ def _imperative_guard(tracer):
|
|
|
|
|
_imperative_tracer_ = tmp_trace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
@signature_safe_contextmanager
|
|
|
|
|
def _imperative_place_guard(place):
|
|
|
|
|
global _imperative_current_expected_place_
|
|
|
|
|
tmp_place = _imperative_current_expected_place_
|
|
|
|
|