|
|
|
@ -13,25 +13,27 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""cell"""
|
|
|
|
|
import inspect
|
|
|
|
|
import time
|
|
|
|
|
import gc
|
|
|
|
|
import inspect
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
import numpy
|
|
|
|
|
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
|
|
|
|
from .. import context
|
|
|
|
|
from .._c_expression import init_backend, Cell_
|
|
|
|
|
from .._checkparam import Validator
|
|
|
|
|
from ..common import dtype as mstype
|
|
|
|
|
from ..common.api import _executor, _pynative_exec
|
|
|
|
|
from .._checkparam import Validator
|
|
|
|
|
from ..common.parameter import Parameter, ParameterTuple
|
|
|
|
|
from .._c_expression import init_backend, Cell_
|
|
|
|
|
from ..ops.primitive import Primitive
|
|
|
|
|
from ..ops.operations import HookBackward
|
|
|
|
|
from ..common.tensor import Tensor, MetaTensor
|
|
|
|
|
from ..ops.functional import cast
|
|
|
|
|
from ..ops.operations import HookBackward
|
|
|
|
|
from ..ops.primitive import Primitive
|
|
|
|
|
from ..parallel._tensor import _load_tensor_by_layout
|
|
|
|
|
from ..common.tensor import Tensor, MetaTensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Cell(Cell_):
|
|
|
|
@ -393,10 +395,10 @@ class Cell(Cell_):
|
|
|
|
|
raise AttributeError("Can not assign params before Cell.__init__() call.")
|
|
|
|
|
if name in self.__dict__:
|
|
|
|
|
if self.__dict__[name] is not None:
|
|
|
|
|
raise TypeError("Expected type is not in (Parameter, Cell), but got Parameter.")
|
|
|
|
|
raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.")
|
|
|
|
|
del self.__dict__[name]
|
|
|
|
|
if cells and name in cells:
|
|
|
|
|
raise TypeError("Expected type is Cell, but got Parameter.")
|
|
|
|
|
raise TypeError("The type of value should be Cell, but got Parameter.")
|
|
|
|
|
self.insert_param_to_cell(name, value)
|
|
|
|
|
elif isinstance(value, ParameterTuple):
|
|
|
|
|
if params is None:
|
|
|
|
@ -417,7 +419,7 @@ class Cell(Cell_):
|
|
|
|
|
if name in self.__dict__:
|
|
|
|
|
del self.__dict__[name]
|
|
|
|
|
if params and name in params:
|
|
|
|
|
raise TypeError("Expected type is Parameter, but got Cell.")
|
|
|
|
|
raise TypeError("The type of value should be Parameter, but got Cell.")
|
|
|
|
|
if self._auto_prefix:
|
|
|
|
|
value.update_parameters_name(name + '.')
|
|
|
|
|
cells[name] = value
|
|
|
|
@ -427,12 +429,13 @@ class Cell(Cell_):
|
|
|
|
|
if isinstance(value, Tensor) and self._params[name] is not None:
|
|
|
|
|
self._params[name].set_data(value)
|
|
|
|
|
elif value is not None:
|
|
|
|
|
raise TypeError("Expected type in (Parameter, ParameterTuple), but got {}.".format(type(value)))
|
|
|
|
|
raise TypeError(f"The type of value should be Parameter or ParameterTuple, "
|
|
|
|
|
f"but got {type(value).__name__}.")
|
|
|
|
|
else:
|
|
|
|
|
self.insert_param_to_cell(name, None)
|
|
|
|
|
elif cells and name in cells:
|
|
|
|
|
if value is not None:
|
|
|
|
|
raise TypeError("Expected type is cell, but got {}.".format(type(value)))
|
|
|
|
|
raise TypeError(f"The type of value should be cell, but got {type(value).__name__}.")
|
|
|
|
|
self._cells[name] = None
|
|
|
|
|
elif isinstance(value, Tensor):
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
|
|
@ -705,6 +708,7 @@ class Cell(Cell_):
|
|
|
|
|
new_p = param.init_data(layout, set_sliced=set_sliced)
|
|
|
|
|
replace[param] = new_p
|
|
|
|
|
return new_p
|
|
|
|
|
|
|
|
|
|
# replace all original usage.
|
|
|
|
|
cells = self.cells_and_names()
|
|
|
|
|
for _, cell in cells:
|
|
|
|
|