|
|
|
@ -280,15 +280,23 @@ class Parameter(MetaTensor):
|
|
|
|
|
Set `default_input` of current `Parameter`.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
data (Union[Tensor, Initializer]): new data.
|
|
|
|
|
slice_shape (bool): If slice the Parameter. Default: False.
|
|
|
|
|
data (Union[Tensor, Initializer, int, float]): new data.
|
|
|
|
|
slice_shape (bool): If slice the Parameter, will not check if shape is match. Default: False.
|
|
|
|
|
|
|
|
|
|
Retruns:
|
|
|
|
|
Parameter, the parameter after set data.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(data, (MetaTensor, Initializer)):
|
|
|
|
|
raise ValueError(f"Parameter data must be `Initializer` or a kind of `MetaTensor` "
|
|
|
|
|
f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.")
|
|
|
|
|
def raise_type_error(incoming):
|
|
|
|
|
raise TypeError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
|
|
|
|
|
f", and incoming is {incoming}. Use .set_dtype(xxx) to change the dtype.")
|
|
|
|
|
|
|
|
|
|
if not isinstance(data, (MetaTensor, Initializer, int, float)):
|
|
|
|
|
raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
|
|
|
|
|
f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.")
|
|
|
|
|
if isinstance(data, (int, float)):
|
|
|
|
|
if self.dtype in mstype.int_type and isinstance(data, float):
|
|
|
|
|
raise_type_error(mstype.float_)
|
|
|
|
|
data = Tensor(data, self.dtype)
|
|
|
|
|
# both not init.
|
|
|
|
|
is_incoming_tensor = isinstance(data, Tensor)
|
|
|
|
|
is_current_tensor = isinstance(self, Tensor)
|
|
|
|
@ -300,25 +308,25 @@ class Parameter(MetaTensor):
|
|
|
|
|
"network, then call this method.")
|
|
|
|
|
if tuple(self.shape) != tuple(data.shape):
|
|
|
|
|
# If Slice create Parameter shape can be change.
|
|
|
|
|
if slice_shape:
|
|
|
|
|
self._update_tensor_data(data)
|
|
|
|
|
self.sliced = True
|
|
|
|
|
else:
|
|
|
|
|
if not slice_shape:
|
|
|
|
|
raise ValueError(f"Can not change the shape of Parameter which has been initialized."
|
|
|
|
|
f" Current shape is {self.shape}, and incoming is {data.shape}.")
|
|
|
|
|
if self.dtype != data.dtype:
|
|
|
|
|
raise ValueError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
|
|
|
|
|
f", and incoming is {data.dtype}. Use .set_dtype(xxx) to change the dtype.")
|
|
|
|
|
raise_type_error(data.dtype)
|
|
|
|
|
if isinstance(data, Initializer):
|
|
|
|
|
# The parameter has been initializered, directly update by the data
|
|
|
|
|
if is_current_tensor:
|
|
|
|
|
self._update_tensor_data(data.to_tensor())
|
|
|
|
|
else:
|
|
|
|
|
# also update the related inited parameter data
|
|
|
|
|
if self.inited_param is not None:
|
|
|
|
|
self.inited_param.set_parameter_data(data)
|
|
|
|
|
self.init_mode = data
|
|
|
|
|
elif is_incoming_tensor or is_current_tensor:
|
|
|
|
|
self._update_tensor_data(data)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Not support to update the Parameter by {data}")
|
|
|
|
|
self.sliced = slice_shape
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def init_data(self, layout=None, set_sliced=False):
|
|
|
|
@ -340,8 +348,6 @@ class Parameter(MetaTensor):
|
|
|
|
|
"""
|
|
|
|
|
if self.init_mode is None:
|
|
|
|
|
return self
|
|
|
|
|
if self.inited_param is not None:
|
|
|
|
|
return self.inited_param
|
|
|
|
|
if layout is not None:
|
|
|
|
|
if not isinstance(layout, list):
|
|
|
|
|
raise TypeError("The layout should be list! layout is {}.".format(layout))
|
|
|
|
@ -362,8 +368,7 @@ class Parameter(MetaTensor):
|
|
|
|
|
if id(obj) != id(self):
|
|
|
|
|
self._inited_param = obj
|
|
|
|
|
obj.init_mode = None
|
|
|
|
|
if set_sliced:
|
|
|
|
|
obj.sliced = True
|
|
|
|
|
obj.sliced = set_sliced
|
|
|
|
|
return obj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|