|
|
|
@ -26,6 +26,7 @@ from ..common.parameter import Parameter, ParameterTuple
|
|
|
|
|
from .._c_expression import init_backend
|
|
|
|
|
from ..ops.primitive import Primitive
|
|
|
|
|
from ..ops.operations import HookBackward
|
|
|
|
|
from ..ops.functional import cast
|
|
|
|
|
from ..parallel._tensor import _load_tensor_by_layout
|
|
|
|
|
from ..common.tensor import Tensor
|
|
|
|
|
|
|
|
|
@ -60,6 +61,7 @@ class Cell:
|
|
|
|
|
def __init__(self, auto_prefix=True, flags=None):
|
|
|
|
|
self._params = OrderedDict()
|
|
|
|
|
self._cells = OrderedDict()
|
|
|
|
|
self._params_list = OrderedDict()
|
|
|
|
|
self.training = False
|
|
|
|
|
self.requires_grad = False
|
|
|
|
|
self.pynative = False
|
|
|
|
@ -188,11 +190,22 @@ class Cell:
|
|
|
|
|
if '_params' in self.__dict__:
|
|
|
|
|
params = self.__dict__['_params']
|
|
|
|
|
if name in params:
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
|
|
|
return self.cast_param(params[name])
|
|
|
|
|
return params[name]
|
|
|
|
|
if '_cells' in self.__dict__:
|
|
|
|
|
cells = self.__dict__['_cells']
|
|
|
|
|
if name in cells:
|
|
|
|
|
return cells[name]
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__:
|
|
|
|
|
params_list = self.__dict__['_params_list']
|
|
|
|
|
if name in params_list:
|
|
|
|
|
para_list = params_list[name]
|
|
|
|
|
cast_list = list()
|
|
|
|
|
for para in para_list:
|
|
|
|
|
cast_list.append(self.cast_param(para))
|
|
|
|
|
para_list = ParameterTuple(cast_list)
|
|
|
|
|
return para_list
|
|
|
|
|
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
@ -225,10 +238,21 @@ class Cell:
|
|
|
|
|
cell.set_grad(True)
|
|
|
|
|
else:
|
|
|
|
|
_pynative_exec.set_grad_flag(False)
|
|
|
|
|
cast_inputs = list()
|
|
|
|
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
|
|
|
|
|
for item in inputs:
|
|
|
|
|
cast_inputs.append(cast(item, mstype.float16))
|
|
|
|
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
|
|
|
|
|
for item in inputs:
|
|
|
|
|
cast_inputs.append(cast(item, mstype.float32))
|
|
|
|
|
if cast_inputs:
|
|
|
|
|
cast_inputs = tuple(cast_inputs)
|
|
|
|
|
else:
|
|
|
|
|
cast_inputs = inputs
|
|
|
|
|
if self.enable_hook:
|
|
|
|
|
output = self._hook_construct(*inputs)
|
|
|
|
|
output = self._hook_construct(*cast_inputs)
|
|
|
|
|
else:
|
|
|
|
|
output = self.construct(*inputs)
|
|
|
|
|
output = self.construct(*cast_inputs)
|
|
|
|
|
if isinstance(output, Parameter):
|
|
|
|
|
output = output.data
|
|
|
|
|
if self.requires_grad is True:
|
|
|
|
@ -241,6 +265,7 @@ class Cell:
|
|
|
|
|
def __setattr__(self, name, value):
|
|
|
|
|
cells = self.__dict__.get('_cells')
|
|
|
|
|
params = self.__dict__.get('_params')
|
|
|
|
|
params_list = self.__dict__.get('_params_list')
|
|
|
|
|
if isinstance(value, Parameter):
|
|
|
|
|
if params is None:
|
|
|
|
|
raise AttributeError("Can not assign params before Cell.__init__() call.")
|
|
|
|
@ -256,7 +281,12 @@ class Cell:
|
|
|
|
|
raise AttributeError("Can not assign params before Cell.__init__() call.")
|
|
|
|
|
for item in value:
|
|
|
|
|
self.insert_param_to_cell(item.name, item, check_name=False)
|
|
|
|
|
object.__setattr__(self, name, value)
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
|
|
|
if name in self.__dict__:
|
|
|
|
|
del self.__dict__[name]
|
|
|
|
|
params_list[name] = value
|
|
|
|
|
else:
|
|
|
|
|
object.__setattr__(self, name, value)
|
|
|
|
|
elif isinstance(value, Cell):
|
|
|
|
|
if cells is None:
|
|
|
|
|
raise AttributeError("Can not assign cells before Cell.__init__() call.")
|
|
|
|
@ -458,6 +488,19 @@ class Cell:
|
|
|
|
|
raise TypeError("The type of parameter should be 'Parameter' if not None.")
|
|
|
|
|
self._params[param_name] = param
|
|
|
|
|
|
|
|
|
|
def cast_param(self, param):
|
|
|
|
|
"""
|
|
|
|
|
Cast parameter according to auto mix precison level in pynative mode.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
param (Parameter): The parameter to cast.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
|
|
|
|
|
return cast(param, mstype.float16)
|
|
|
|
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
|
|
|
|
|
return cast(param, mstype.float32)
|
|
|
|
|
return param
|
|
|
|
|
|
|
|
|
|
def insert_child_to_cell(self, child_name, child):
|
|
|
|
|
"""
|
|
|
|
|
Adds a child cell to the current cell.
|
|
|
|
|