diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index 820c5b59de..54c0a1debe 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -338,12 +338,6 @@ def initializer(init, shape=None, dtype=mstype.float32): "the variable shape {}.".format(list(init.shape()), shape)) return init - if isinstance(init, str): - init_obj = _INITIALIZER_ALIAS[init.lower()]() - if init_obj is None: - raise ValueError("The class corresponding to '{}' was not found.".format(init)) - init = init_obj - if isinstance(shape, list): shape = tuple(shape) elif isinstance(shape, numbers.Number): @@ -354,6 +348,15 @@ def initializer(init, shape=None, dtype=mstype.float32): raise ValueError("Error shape={}".format(shape)) if isinstance(init, Initializer): + init.shape = init.shape if init.shape is not None else shape + init.dtype = init.dtype if init.dtype is not None else dtype + return init + + if isinstance(init, str): + init_obj = _INITIALIZER_ALIAS[init.lower()]() + if init_obj is None: + raise ValueError("The class corresponding to '{}' was not found.".format(init)) + init = init_obj init.shape = shape init.dtype = dtype return init diff --git a/tests/ut/python/utils/test_initializer.py b/tests/ut/python/utils/test_initializer.py index 417d0bb2b1..57709baa76 100644 --- a/tests/ut/python/utils/test_initializer.py +++ b/tests/ut/python/utils/test_initializer.py @@ -141,7 +141,18 @@ def test_init_abnormal(): with py.raises(TypeError): init.initializer([''], [5, 4], ms.float32) - +def test_initializer_reinit(): + weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16) + assert weights.dtype == ms.float16 + assert weights.shape == (10, 1, 10, 10) + weights = init.initializer(weights) + assert weights.dtype == ms.float16 + assert weights.shape == (10, 1, 10, 10) + weights.shape = None + weights = init.initializer(weights, (10, 1)) + assert weights.dtype == ms.float16 + assert weights.shape == (10, 1) + def test_init_xavier_uniform(): """ test_init_xavier_uniform """ gain = 1.2