|
|
|
@ -75,20 +75,20 @@ GRAD_VAR_SUFFIX = core.kGradVarSuffix()
|
|
|
|
|
ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
|
|
|
|
|
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
|
|
|
|
|
|
|
|
|
|
_imperative_tracer_ = None
|
|
|
|
|
_imperative_current_expected_place_ = None
|
|
|
|
|
_dygraph_tracer_ = None
|
|
|
|
|
_dygraph_current_expected_place_ = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _in_imperative_mode():
|
|
|
|
|
return _imperative_tracer_ is not None
|
|
|
|
|
def _in_dygraph_mode():
|
|
|
|
|
return _dygraph_tracer_ is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _imperative_tracer():
|
|
|
|
|
return _imperative_tracer_
|
|
|
|
|
def _dygraph_tracer():
|
|
|
|
|
return _dygraph_tracer_
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _current_expected_place():
|
|
|
|
|
return _imperative_current_expected_place_
|
|
|
|
|
return _dygraph_current_expected_place_
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cpu_num():
|
|
|
|
@ -396,7 +396,7 @@ class Variable(object):
|
|
|
|
|
if not isinstance(dtype, core.VarDesc.VarType):
|
|
|
|
|
dtype = convert_np_dtype_to_dtype_(dtype)
|
|
|
|
|
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
# record vars in tracer rather than blocks
|
|
|
|
|
self._ivar = kwargs.get("ivar", None)
|
|
|
|
|
if not self._ivar:
|
|
|
|
@ -406,7 +406,7 @@ class Variable(object):
|
|
|
|
|
_current_expected_place(), stop_gradient, True
|
|
|
|
|
if persistable else False)
|
|
|
|
|
if persistable:
|
|
|
|
|
_imperative_tracer().trace_var(name, self)
|
|
|
|
|
_dygraph_tracer().trace_var(name, self)
|
|
|
|
|
else:
|
|
|
|
|
self.error_clip = error_clip
|
|
|
|
|
|
|
|
|
@ -515,8 +515,8 @@ class Variable(object):
|
|
|
|
|
Returns:
|
|
|
|
|
str: The debug string.
|
|
|
|
|
"""
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
# TODO(panyx0718): add more imperative debug info.
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
# TODO(panyx0718): add more dygraph debug info.
|
|
|
|
|
return 'name %s, dtype: %s shape: %s' % (self.name, self.dtype,
|
|
|
|
|
self.shape)
|
|
|
|
|
|
|
|
|
@ -548,42 +548,42 @@ class Variable(object):
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _stop_gradient(self):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
return self._ivar.stop_gradient
|
|
|
|
|
else:
|
|
|
|
|
return self.stop_gradient
|
|
|
|
|
|
|
|
|
|
@_stop_gradient.setter
|
|
|
|
|
def _stop_gradient(self, s):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
self._ivar.stop_gradient = s
|
|
|
|
|
else:
|
|
|
|
|
self.stop_gradient = s
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def persistable(self):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
return self._ivar.persistable
|
|
|
|
|
else:
|
|
|
|
|
return self.desc.persistable()
|
|
|
|
|
|
|
|
|
|
@persistable.setter
|
|
|
|
|
def persistable(self, p):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
return self._ivar.persistable
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_persistable(p)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def name(self):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
return self._ivar.name
|
|
|
|
|
else:
|
|
|
|
|
return cpt.to_text(self.desc.name())
|
|
|
|
|
|
|
|
|
|
@name.setter
|
|
|
|
|
def name(self, new_name):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
self._ivar.name = new_name
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_name(new_name)
|
|
|
|
@ -591,26 +591,26 @@ class Variable(object):
|
|
|
|
|
@property
|
|
|
|
|
def shape(self):
|
|
|
|
|
# convert to tuple, make it as same as numpy API.
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
return self._ivar.shape
|
|
|
|
|
else:
|
|
|
|
|
return tuple(self.desc.shape())
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def dtype(self):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
return self._ivar.dtype
|
|
|
|
|
else:
|
|
|
|
|
return self.desc.dtype()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def lod_level(self):
|
|
|
|
|
# TODO(minqiyang): Support lod_level in imperative mode
|
|
|
|
|
# TODO(minqiyang): Support lod_level in dygraph mode
|
|
|
|
|
return self.desc.lod_level()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def type(self):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
return self._ivar.dtype
|
|
|
|
|
else:
|
|
|
|
|
return self.desc.type()
|
|
|
|
@ -918,7 +918,7 @@ class Operator(object):
|
|
|
|
|
inputs=None,
|
|
|
|
|
outputs=None,
|
|
|
|
|
attrs=None):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
if type is None:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"`type` to initialized an Operator can not be None.")
|
|
|
|
@ -1037,7 +1037,7 @@ class Operator(object):
|
|
|
|
|
for arg in out_args:
|
|
|
|
|
out_arg_names.append(cpt.to_text(arg.name))
|
|
|
|
|
# TODO(minqiyang): could we remove variable's op in static mode?
|
|
|
|
|
if not _in_imperative_mode():
|
|
|
|
|
if not _in_dygraph_mode():
|
|
|
|
|
arg.op = self
|
|
|
|
|
self.desc.set_output(out_proto.name, out_arg_names)
|
|
|
|
|
|
|
|
|
@ -1083,7 +1083,7 @@ class Operator(object):
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def type(self):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
return self.iop.type
|
|
|
|
|
else:
|
|
|
|
|
return self.desc.type()
|
|
|
|
@ -1626,7 +1626,7 @@ class Block(object):
|
|
|
|
|
Returns:
|
|
|
|
|
Operator: the append Operator.
|
|
|
|
|
"""
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
op = Operator(
|
|
|
|
|
block=self,
|
|
|
|
|
desc=None,
|
|
|
|
@ -1638,9 +1638,8 @@ class Block(object):
|
|
|
|
|
# record ops in tracer rather than blocks
|
|
|
|
|
#
|
|
|
|
|
# TODO(minqiyang): add op stop_gradient support in static mode too.
|
|
|
|
|
# currently, we only support stop_gradient in imperative mode.
|
|
|
|
|
_imperative_tracer().trace_op(op,
|
|
|
|
|
kwargs.get("stop_gradient", False))
|
|
|
|
|
# currently, we only support stop_gradient in dygraph mode.
|
|
|
|
|
_dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
|
|
|
|
|
else:
|
|
|
|
|
op_desc = self.desc.append_op()
|
|
|
|
|
op = Operator(
|
|
|
|
@ -1699,7 +1698,7 @@ class Block(object):
|
|
|
|
|
return self.ops[start:end]
|
|
|
|
|
|
|
|
|
|
def _prepend_op(self, *args, **kwargs):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
if _in_dygraph_mode():
|
|
|
|
|
op = Operator(
|
|
|
|
|
self,
|
|
|
|
|
None,
|
|
|
|
@ -1707,8 +1706,7 @@ class Block(object):
|
|
|
|
|
inputs=kwargs.get("inputs", None),
|
|
|
|
|
outputs=kwargs.get("outputs", None),
|
|
|
|
|
attrs=kwargs.get("attrs", None))
|
|
|
|
|
_imperative_tracer().trace_op(op,
|
|
|
|
|
kwargs.get("stop_gradient", False))
|
|
|
|
|
_dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
|
|
|
|
|
else:
|
|
|
|
|
op_desc = self.desc._prepend_op()
|
|
|
|
|
op = Operator(
|
|
|
|
@ -3541,22 +3539,22 @@ def _get_var(name, program=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@signature_safe_contextmanager
|
|
|
|
|
def _imperative_guard(tracer):
|
|
|
|
|
global _imperative_tracer_
|
|
|
|
|
tmp_trace = _imperative_tracer_
|
|
|
|
|
_imperative_tracer_ = tracer
|
|
|
|
|
def _dygraph_guard(tracer):
|
|
|
|
|
global _dygraph_tracer_
|
|
|
|
|
tmp_trace = _dygraph_tracer_
|
|
|
|
|
_dygraph_tracer_ = tracer
|
|
|
|
|
|
|
|
|
|
yield
|
|
|
|
|
|
|
|
|
|
_imperative_tracer_ = tmp_trace
|
|
|
|
|
_dygraph_tracer_ = tmp_trace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@signature_safe_contextmanager
|
|
|
|
|
def _imperative_place_guard(place):
|
|
|
|
|
global _imperative_current_expected_place_
|
|
|
|
|
tmp_place = _imperative_current_expected_place_
|
|
|
|
|
_imperative_current_expected_place_ = place
|
|
|
|
|
def _dygraph_place_guard(place):
|
|
|
|
|
global _dygraph_current_expected_place_
|
|
|
|
|
tmp_place = _dygraph_current_expected_place_
|
|
|
|
|
_dygraph_current_expected_place_ = place
|
|
|
|
|
|
|
|
|
|
yield
|
|
|
|
|
|
|
|
|
|
_imperative_current_expected_place_ = tmp_place
|
|
|
|
|
_dygraph_current_expected_place_ = tmp_place
|
|
|
|
|