From 2de97f256fbdefed35340f1cc021a3f89227a7d2 Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Tue, 31 Mar 2020 16:12:43 +0800 Subject: [PATCH] iterfaces change: _Constant to Constant --- mindspore/common/initializer.py | 21 +++++++++++++++------ tests/ut/python/utils/test_initializer.py | 9 +++++++-- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index 99b4501307..bdc3418129 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -180,18 +180,18 @@ class HeUniform(Initializer): _assignment(arr, data) -class _Constant(Initializer): +class Constant(Initializer): """ Initialize a constant. Args: - value (int or numpy.ndarray): The value to initialize. + value (Union[int, numpy.ndarray]): The value to initialize. Returns: Array, initialize array. """ def __init__(self, value): - super(_Constant, self).__init__(value=value) + super(Constant, self).__init__(value=value) self.value = value def _initialize(self, arr): @@ -266,8 +266,16 @@ def initializer(init, shape=None, dtype=mstype.float32): Args: init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value. + + - `str`: The `init` should be the alias of the class inheriting from `Initializer` and the corresponding + class will be called. + + - `Initializer`: The `init` should be the class inheriting from `Initializer` to initialize tensor. + + - `numbers.Number`: The `Constant` will be called to initialize tensor. + shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of - output. Default: None. + output. Default: None. dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mstype.float32. Returns: @@ -295,7 +303,7 @@ def initializer(init, shape=None, dtype=mstype.float32): raise ValueError(msg) if isinstance(init, numbers.Number): - init_obj = _Constant(init) + init_obj = Constant(init) elif isinstance(init, str): init_obj = _INITIALIZER_ALIAS[init.lower()]() else: @@ -314,4 +322,5 @@ __all__ = [ 'HeUniform', 'XavierUniform', 'One', - 'Zero'] + 'Zero', + 'Constant'] diff --git a/tests/ut/python/utils/test_initializer.py b/tests/ut/python/utils/test_initializer.py index ff7ab8d119..31d2434341 100644 --- a/tests/ut/python/utils/test_initializer.py +++ b/tests/ut/python/utils/test_initializer.py @@ -37,8 +37,8 @@ def _check_value(tensor, value_min, value_max): for ele in nd.flatten(): if value_min <= ele <= value_max: continue - raise TypeError('value_min = %d, ele = %d, value_max = %d' - % (value_min, ele, value_max)) + raise ValueError('value_min = %d, ele = %d, value_max = %d' + % (value_min, ele, value_max)) def _check_uniform(tensor, boundary_a, boundary_b): @@ -92,6 +92,11 @@ def test_init_one_alias(): _check_value(tensor, 1, 1) +def test_init_constant(): + tensor = init.initializer(init.Constant(1), [2, 2], ms.float32) + _check_value(tensor, 1, 1) + + def test_init_uniform(): scale = 10 tensor = init.initializer(init.Uniform(scale=scale), [5, 4], ms.float32)