diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 6c88b7d957..641558921a 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -344,5 +344,5 @@ class ParameterUpdate(Cell): self._param = param def construct(self, x): - self._param = x + F.assign(self._param, x) return x diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index e933d40666..49cc5318fa 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -408,10 +408,11 @@ def _fill_param_into_net(net, parameter_list): for each_param in parameter_list: param_name = each_param["name"] np_val = each_param["data"].asnumpy() - if np_val.shape == (1,): # to scalar - parameter_dict[param_name] = Parameter(np_val[0], name=param_name) + if np_val.shape == (1,): + parameter_dict[param_name] = Parameter(np_val, name=param_name) elif np_val.shape == (): - parameter_dict[param_name] = Parameter(np_val.tolist(), name=param_name) + parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)), + name=param_name) else: parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name) diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index 529af532f7..d6bc40ba02 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -52,12 +52,69 @@ def test_parameter_tuple_illegal(): def test_parameter_init_illegal(): + import numpy as np + dat = np.array([[1, 2, 3], [2, 3, 4]]) + tensor = Tensor(dat) + data_none = None data_bool = True data_str = "nicai" + data_int = 3 + data_list = [1, "2", True] + data_tuple = (1, 2, 3) + + # test data + Parameter(tensor, name=data_str) + Parameter(data_int, name=data_str) + Parameter(dat, name=data_str) with pytest.raises(ValueError): Parameter(data_bool, name=data_str) + # test name + Parameter(tensor, name=data_none) + with pytest.raises(ValueError): + Parameter(tensor, name=dat) + with pytest.raises(ValueError): + Parameter(tensor, name=tensor) + with pytest.raises(ValueError): + Parameter(tensor, name=data_bool) + with pytest.raises(ValueError): + Parameter(tensor, name=data_int) + with pytest.raises(ValueError): + Parameter(tensor, name=data_list) + with pytest.raises(ValueError): + Parameter(tensor, name=data_tuple) + + Parameter(tensor, name=data_str, requires_grad=data_bool) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_none) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=dat) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=tensor) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_str) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_int) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_list) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_tuple) + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_bool) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=dat) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=tensor) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_none) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_str) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_int) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_list) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_tuple) def test_check_str_by_regular():