|
|
|
|
@ -15,25 +15,22 @@
|
|
|
|
|
|
|
|
|
|
"""Parameter for cell."""
|
|
|
|
|
from copy import copy
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from .._c_expression import ParamValue
|
|
|
|
|
from . import dtype as mstype
|
|
|
|
|
from .initializer import initializer, Initializer
|
|
|
|
|
from .tensor import Tensor, MetaTensor
|
|
|
|
|
from .._checkparam import _check_str_by_regular
|
|
|
|
|
from ..parallel._tensor import _get_slice_index
|
|
|
|
|
from ..parallel._auto_parallel_context import auto_parallel_context
|
|
|
|
|
|
|
|
|
|
__all__ = ['Parameter', 'ParameterTuple']
|
|
|
|
|
|
|
|
|
|
PARAMETER_NAME_DEFAULT = "Parameter"
|
|
|
|
|
PARAMETER_NAME_PREFIX_MAX_LEN = 1024
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_type(x):
|
|
|
|
|
"""Check input data type"""
|
|
|
|
|
if not isinstance(x, Parameter):
|
|
|
|
|
raise ValueError("Should be `Parameter` collection.")
|
|
|
|
|
return True
|
|
|
|
|
def _is_in_parallel_mode():
|
|
|
|
|
"""Get parallel mode."""
|
|
|
|
|
return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Parameter(MetaTensor):
|
|
|
|
|
@ -42,10 +39,10 @@ class Parameter(MetaTensor):
|
|
|
|
|
|
|
|
|
|
After initialized `Parameter` is a subtype of `Tensor`.
|
|
|
|
|
|
|
|
|
|
In graph mode, if init `Parameter` by a `Initializer`, the type of Parameter will be a `MetaTensor`
|
|
|
|
|
not a `Tensor`. `MetaTensor` only save the shape type info of a tensor with no memory usage. The shape
|
|
|
|
|
can be change while compile for auto-parallel. Call `init_data` will return a Tensor Parameter with
|
|
|
|
|
initialized data.
|
|
|
|
|
In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
|
|
|
|
|
a `Initializer`, the type of Parameter will be a `MetaTensor` not a `Tensor`. `MetaTensor`
|
|
|
|
|
only save the shape type info of a tensor with no memory usage. The shape can be change while
|
|
|
|
|
compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Each parameter of Cell is represented by Parameter class.
|
|
|
|
|
@ -67,7 +64,7 @@ class Parameter(MetaTensor):
|
|
|
|
|
input_class.__init__(obj, *class_init_args)
|
|
|
|
|
# it's better to make the Initializer a kind of metatensor.
|
|
|
|
|
obj.init_mode = None
|
|
|
|
|
if isinstance(default_input, Initializer):
|
|
|
|
|
if not isinstance(obj, Tensor):
|
|
|
|
|
obj.init_mode = default_input
|
|
|
|
|
return obj
|
|
|
|
|
|
|
|
|
|
@ -112,11 +109,10 @@ class Parameter(MetaTensor):
|
|
|
|
|
if isinstance(data, bool):
|
|
|
|
|
raise ValueError('Parameter data can not be `bool`')
|
|
|
|
|
if isinstance(data, Initializer):
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
|
|
|
# always init data while in pynative mode.
|
|
|
|
|
data = data.to_tensor()
|
|
|
|
|
else:
|
|
|
|
|
if _is_in_parallel_mode():
|
|
|
|
|
# do not init data while in auto parallel.
|
|
|
|
|
return (MetaTensor, data.dtype, data.shape)
|
|
|
|
|
data = data.to_tensor()
|
|
|
|
|
if isinstance(data, Tensor):
|
|
|
|
|
# make a copy of Tensor to init the parameter
|
|
|
|
|
return (Tensor, data.asnumpy(),)
|
|
|
|
|
@ -127,9 +123,9 @@ class Parameter(MetaTensor):
|
|
|
|
|
return (Tensor, data)
|
|
|
|
|
|
|
|
|
|
def __str__(self):
|
|
|
|
|
value_str = MetaTensor.__repr__(self)
|
|
|
|
|
value_str = MetaTensor.__str__(self)
|
|
|
|
|
if isinstance(self, Tensor):
|
|
|
|
|
value_str = Tensor.__repr__(self)
|
|
|
|
|
value_str = Tensor.__str__(self)
|
|
|
|
|
return f'Parameter (name={self._value.name}, value={value_str})'
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
@ -235,8 +231,6 @@ class Parameter(MetaTensor):
|
|
|
|
|
shape = self.shape
|
|
|
|
|
dtype = self.dtype
|
|
|
|
|
x.default_input = initializer(init, shape=shape, dtype=dtype)
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
|
|
|
x.init_data()
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
@ -381,8 +375,12 @@ class ParameterTuple(tuple):
|
|
|
|
|
"""
|
|
|
|
|
def __new__(cls, iterable):
|
|
|
|
|
"""Create instance object of ParameterTuple."""
|
|
|
|
|
g = (x for x in iterable if _check_type(x))
|
|
|
|
|
return tuple.__new__(ParameterTuple, g)
|
|
|
|
|
data = tuple(iterable)
|
|
|
|
|
for x in data:
|
|
|
|
|
if not isinstance(x, Parameter):
|
|
|
|
|
raise TypeError(f"ParameterTuple input should be `Parameter` collection."
|
|
|
|
|
f"But got a {type(iterable)}, {iterable}")
|
|
|
|
|
return tuple.__new__(ParameterTuple, tuple(data))
|
|
|
|
|
|
|
|
|
|
def clone(self, prefix, init='same'):
|
|
|
|
|
"""
|
|
|
|
|
|