From 7b135990e77c0ab229b341066c5a283271686ff7 Mon Sep 17 00:00:00 2001 From: lilei Date: Thu, 15 Oct 2020 22:17:06 +0800 Subject: [PATCH] Use MetaTensor instead of Initializer --- mindspore/common/initializer.py | 80 +++---------------- mindspore/common/parameter.py | 43 +++++----- mindspore/common/tensor.py | 47 ++++++++++- tests/st/pynative/test_pynative_resnet50.py | 36 +++------ tests/ut/python/pynative_mode/test_staging.py | 6 +- tests/ut/python/utils/test_initializer.py | 22 ++--- 6 files changed, 96 insertions(+), 138 deletions(-) diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index 24faa74ac0..09319fbbd5 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -15,16 +15,13 @@ """Initializer for cell parameters.""" import numbers import math -import copy from functools import reduce import numpy as np from scipy.stats import truncnorm -from mindspore import log as logger from . import dtype as mstype -from .tensor import Tensor -from .seed import get_seed +from .tensor import Tensor, MetaTensor from .._c_expression import random_normal _INITIALIZER_ALIAS = dict() @@ -52,54 +49,6 @@ class Initializer: def __call__(self, arr): return self._initialize(arr) - @property - def shape(self): - return self._shape - - @shape.setter - def shape(self, shape): - self._shape = shape - - @property - def dtype(self): - return self._dtype - - @dtype.setter - def dtype(self, dtype): - self._dtype = dtype - - def to_tensor(self, slice_index=None, shape=None): - """ - Get the tensor format data of this Initializer. - - Args: - slice_index (int): Slice index of a parameter's slices. - It is used when initialize a slice of a parameter, it guarantees that devices - using the same slice can generate the same tensor. - shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter. 
- """ - arr = None - if shape is None: - shape = self.shape - - try: - arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype)) - except ValueError: - msg = "Error shape={}".format(shape) - logger.error(msg) - raise ValueError(msg) - - global_seed = get_seed() - need_set_seed = ((slice_index is not None) and (global_seed is None)) - seed_saved = np.random.get_state()[1][0] - if need_set_seed: - np.random.seed(slice_index) - self.__call__(arr) - if need_set_seed: - np.random.seed(seed_saved) - return Tensor(arr, dtype=self.dtype) - - def _register(*aliases): """Return the alias register.""" def alias_reg(cls): @@ -478,27 +427,16 @@ def initializer(init, shape=None, dtype=mstype.float32): if not isinstance(value, int) or value <= 0: raise ValueError(f"shape is invalid, shape value must be positive integer, shape:{shape}") - if isinstance(init, Initializer): - init_copy = copy.deepcopy(init) - init_copy.shape = shape if shape is not None else init.shape - init_copy.dtype = init.dtype if init.dtype is not None else dtype - return init_copy - if isinstance(init, str): - init_obj = _INITIALIZER_ALIAS[init.lower()]() - if init_obj is None: + init = _INITIALIZER_ALIAS[init.lower()]() + if init is None: raise ValueError("The class corresponding to '{}' was not found.".format(init)) - init = init_obj - init.shape = shape - init.dtype = dtype - return init - - if isinstance(init, numbers.Number): - init_obj = Constant(init) - init_obj.shape = shape - init_obj.dtype = dtype - return init_obj - raise TypeError("Unsupported init type '{}'.".format(type(init))) + elif isinstance(init, numbers.Number): + init = Constant(init) + shape = shape if shape is not None else init.shape + dtype = init.dtype if init.dtype is not None else dtype + init_obj = MetaTensor(init, dtype, shape) + return init_obj __all__ = [ 'Initializer', diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 6ad4cb19d2..0ce6812862 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -16,8 +16,9 @@ """Parameter for cell.""" from copy import copy from .._c_expression import ParamInfo +from .._c_expression import MetaTensor as MetaTensor_ from . import dtype as mstype -from .initializer import initializer, Initializer +from .initializer import initializer from .tensor import Tensor, MetaTensor from .._checkparam import _check_str_by_regular from ..parallel._tensor import _get_slice_index @@ -34,14 +35,14 @@ def _is_in_parallel_mode(): return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"] -class Parameter(MetaTensor): +class Parameter(MetaTensor_): """ Parameter types of cell models. After initialized `Parameter` is a subtype of `Tensor`. In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by - an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor` + an `MetaTensor`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor_` only saves the shape and type info of a tensor with no memory usage. The shape can be changed while compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data. @@ -52,7 +53,7 @@ class Parameter(MetaTensor): then the Parameters as this part of the inputs are not allowed to be cast. Args: - default_input (Union[Tensor, Initializer, Number]): Parameter data, to be set initialized. + default_input (Union[Tensor, MetaTensor, Number]): Parameter data, to be set initialized. 
name (str): Name of the child parameter. requires_grad (bool): True if the parameter requires gradient. Default: True. layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode, @@ -94,9 +95,9 @@ class Parameter(MetaTensor): input_class.__init__(obj, *class_init_args) # it's better to make the Initializer a kind of metatensor. obj.init_mode = None - obj.is_default_input_initializer = False - if isinstance(default_input, Initializer): - obj.is_default_input_initializer = True + obj.is_default_input_meta = False + if isinstance(default_input, MetaTensor): + obj.is_default_input_meta = True if not isinstance(obj, Tensor): obj.init_mode = default_input return obj @@ -142,10 +143,10 @@ class Parameter(MetaTensor): """Set `set_data` of current `Parameter`.""" if isinstance(data, bool): raise ValueError('Parameter data can not be `bool`') - if isinstance(data, Initializer): + if isinstance(data, MetaTensor): if _is_in_parallel_mode(): # do not init data while in auto parallel. - return (MetaTensor, data.dtype, data.shape) + return (MetaTensor_, data.dtype, data.shape) data = data.to_tensor() if isinstance(data, Tensor): # make a copy of Tensor to init the parameter @@ -257,7 +258,7 @@ class Parameter(MetaTensor): Args: prefix (str): Namespace of parameter. The cloned Parameter name is combined of prefix and current name: `f"{perfix}.{self.name}"`. - init (Union[Tensor, str, Initializer, numbers.Number]): Initialize the shape of the parameter. + init (Union[Tensor, str, MetaTensor, numbers.Number]): Initialize the shape of the parameter. Default: 'same'. Returns: @@ -314,7 +315,7 @@ class Parameter(MetaTensor): Set `set_data` of current `Parameter`. Args: - data (Union[Tensor, Initializer, int, float]): new data. + data (Union[Tensor, MetaTensor, int, float]): new data. slice_shape (bool): If slice the Parameter, will not check if shape is match. Default: False. Retruns: @@ -325,9 +326,9 @@ class Parameter(MetaTensor): f"Current dtype is {self.dtype}, and incoming is {incoming}. " f"Use .set_dtype(xxx) to change the dtype.") - if not isinstance(data, (MetaTensor, Initializer, int, float)): - raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` " - f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.") + if not isinstance(data, (MetaTensor_, int, float)): + raise TypeError(f"Parameter data must be [`MetaTensor`, `int`, `float`] or a kind of `MetaTensor_` " + f"(like `Tensor` or `MetaTensor_`). But with type {type(data)}.") if isinstance(data, (int, float)): if self.dtype in mstype.int_type and isinstance(data, float): raise_type_error(mstype.float_) @@ -337,8 +338,8 @@ class Parameter(MetaTensor): is_current_tensor = isinstance(self, Tensor) if is_incoming_tensor and not is_current_tensor: - raise TypeError("Parameter is a `MetaTensor` and not initializered, `data` for `set_data`" - "should be a Initializer. If you want to update it by Tensor, call method" + raise TypeError("Parameter is a `MetaTensor_` and not initializered, `data` for `set_data`" + "should be a MetaTensor. 
If you want to update it by Tensor, call method" "`init_parameters_data` of `Cell` to init and replace all the Parameter of" "network, then call this method.") if tuple(self.shape) != tuple(data.shape): @@ -351,7 +352,7 @@ class Parameter(MetaTensor): raise_type_error(data.dtype) else: data = Tensor(data, self.dtype) - if isinstance(data, Initializer): + if isinstance(data, MetaTensor): # The parameter has been initializered, directly update by the data if is_current_tensor: self._update_tensor_data(data.to_tensor()) @@ -387,10 +388,10 @@ class Parameter(MetaTensor): Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before, returns the same initialized `Parameter`. """ - if self.is_default_input_initializer: + if self.is_default_input_meta: is_current_in_parallel = _is_in_parallel_mode() if self.is_in_parallel != is_current_in_parallel: - raise RuntimeError("Must set or change parallel mode before any Initializer created.") + raise RuntimeError("Must set or change parallel mode before any MetaTensor created.") if self.init_mode is None: return self if self.inited_param is not None: @@ -401,12 +402,12 @@ class Parameter(MetaTensor): if len(layout) < 3: raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout)) slice_index = int(_get_slice_index(layout[0], layout[1])) - if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)): + if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): data = self.init_mode.to_tensor(0, [1]) else: data = self.init_mode.to_tensor(slice_index, layout[2]) else: - if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)): + if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): data = self.init_mode.to_tensor(0, [1]) else: data = self.init_mode.to_tensor() diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 643e2873d6..636789a617 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -15,8 +15,9 @@ """Tensor implementation.""" import numpy as np +from mindspore import log as logger from .._c_expression import Tensor as Tensor_ -from .._c_expression import MetaTensor +from .._c_expression import MetaTensor as MetaTensor_ from .._checkparam import check_type, check_typename from . import dtype as mstype from ._register_for_tensor import tensor_operator_registry @@ -395,6 +396,50 @@ class SparseTensor: return self.__dense_shape +class MetaTensor(MetaTensor_): + """ + The base class of the MetaTensor. + Initialization of tensor basic attributes and model weight values. + + Returns: + Array, an array after being initialized. + """ + def __init__(self, init, dtype, shape): + #check param + self.init = init + MetaTensor_.__init__(self, dtype, shape) + + def to_tensor(self, slice_index=None, shape=None): + """ + Get the tensor format data of this MetaTensor. + + Args: + slice_index (int): Slice index of a parameter's slices. + It is used when initialize a slice of a parameter, it guarantees that devices + using the same slice can generate the same tensor. + shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter. 
+ """ + if shape is None: + shape = self.shape + + try: + arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype)) + except ValueError: + msg = "Error shape={}".format(shape) + logger.error(msg) + raise ValueError(msg) + from .seed import get_seed + global_seed = get_seed() + need_set_seed = ((slice_index is not None) and (global_seed is None)) + seed_saved = np.random.get_state()[1][0] + if need_set_seed: + np.random.seed(slice_index) + self.init(arr) + if need_set_seed: + np.random.seed(seed_saved) + return Tensor(arr, dtype=self.dtype) + + def _vm_compare(*args): """Implement `vm_compare` for tensor.""" obj_str = args[-1] diff --git a/tests/st/pynative/test_pynative_resnet50.py b/tests/st/pynative/test_pynative_resnet50.py index 6364ac7d3a..b54e4a0161 100644 --- a/tests/st/pynative/test_pynative_resnet50.py +++ b/tests/st/pynative/test_pynative_resnet50.py @@ -32,7 +32,6 @@ from mindspore.nn import Cell from mindspore.ops import operations as P from mindspore.ops import composite as CP from mindspore.nn.optim.momentum import Momentum -from mindspore.common.initializer import initializer from mindspore.nn.wrap.cell_wrapper import WithLossCell random.seed(1) @@ -43,14 +42,6 @@ ds.config.set_seed(1) grad_by_list = CP.GradOperation(get_by_list=True) -def weight_variable(shape): - return initializer('XavierUniform', shape=shape, dtype=mstype.float32) - - -def weight_variable_uniform(shape): - return initializer('Uniform', shape=shape, dtype=mstype.float32) - - def weight_variable_0(shape): zeros = np.zeros(shape).astype(np.float32) return Tensor(zeros) @@ -63,26 +54,23 @@ def weight_variable_1(shape): def conv3x3(in_channels, out_channels, stride=1, padding=0): """3x3 convolution """ - weight_shape = (out_channels, in_channels, 3, 3) - weight = weight_variable(weight_shape) return nn.Conv2d(in_channels, out_channels, - kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + kernel_size=3, stride=stride, padding=padding, weight_init='XavierUniform', + has_bias=False, pad_mode="same") def conv1x1(in_channels, out_channels, stride=1, padding=0): """1x1 convolution""" - weight_shape = (out_channels, in_channels, 1, 1) - weight = weight_variable(weight_shape) return nn.Conv2d(in_channels, out_channels, - kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + kernel_size=1, stride=stride, padding=padding, weight_init='XavierUniform', + has_bias=False, pad_mode="same") def conv7x7(in_channels, out_channels, stride=1, padding=0): """1x1 convolution""" - weight_shape = (out_channels, in_channels, 7, 7) - weight = weight_variable(weight_shape) return nn.Conv2d(in_channels, out_channels, - kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + kernel_size=7, stride=stride, padding=padding, weight_init='XavierUniform', + has_bias=False, pad_mode="same") def bn_with_initialize(out_channels): @@ -90,8 +78,7 @@ def bn_with_initialize(out_channels): mean = weight_variable_0(shape) var = weight_variable_1(shape) beta = weight_variable_0(shape) - gamma = weight_variable_uniform(shape) - bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma, + bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform', beta_init=beta, moving_mean_init=mean, moving_var_init=var) return bn @@ -101,18 +88,13 @@ def bn_with_initialize_last(out_channels): mean = weight_variable_0(shape) var = weight_variable_1(shape) beta = 
weight_variable_0(shape) - gamma = weight_variable_uniform(shape) - bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma, + bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform', beta_init=beta, moving_mean_init=mean, moving_var_init=var) return bn def fc_with_initialize(input_channels, out_channels): - weight_shape = (out_channels, input_channels) - weight = weight_variable(weight_shape) - bias_shape = (out_channels) - bias = weight_variable_uniform(bias_shape) - return nn.Dense(input_channels, out_channels, weight, bias) + return nn.Dense(input_channels, out_channels, weight_init='XavierUniform', bias_init='Uniform') class ResidualBlock(nn.Cell): diff --git a/tests/ut/python/pynative_mode/test_staging.py b/tests/ut/python/pynative_mode/test_staging.py index c5e7396739..cc91ac8913 100644 --- a/tests/ut/python/pynative_mode/test_staging.py +++ b/tests/ut/python/pynative_mode/test_staging.py @@ -20,7 +20,7 @@ import mindspore as ms import mindspore.nn as nn from mindspore import Tensor from mindspore import context -from mindspore.common import MetaTensor +from mindspore._c_expression import MetaTensor as MetaTensor_ from mindspore.common import dtype from mindspore.common.api import ms_function from mindspore.ops import functional as F @@ -70,8 +70,8 @@ def scalar_mul_while(x): return rv -@ms_function(input_signature=(MetaTensor(dtype.float32, (1, 1, 3, 3)), - MetaTensor(dtype.float32, (1, 1, 3, 3)))) +@ms_function(input_signature=(MetaTensor_(dtype.float32, (1, 1, 3, 3)), + MetaTensor_(dtype.float32, (1, 1, 3, 3)))) def tensor_add_test(x, y): """ tensor_add_test """ z = F.tensor_add(x, y) diff --git a/tests/ut/python/utils/test_initializer.py b/tests/ut/python/utils/test_initializer.py index 1d2498b959..b2d8a800a1 100644 --- a/tests/ut/python/utils/test_initializer.py +++ b/tests/ut/python/utils/test_initializer.py @@ -24,7 +24,7 @@ import mindspore.common.initializer as init import mindspore.nn as nn from mindspore import context from mindspore.common.parameter import Parameter -from mindspore.common.tensor import Tensor +from mindspore.common.tensor import Tensor, MetaTensor from mindspore.nn import Conv2d from mindspore.ops import operations as P from ..ut_filter import non_graph_engine @@ -58,7 +58,7 @@ def _check_uniform(tensor, boundary_a, boundary_b): def test_init_Initializer(): tensor = init.initializer(InitTwo(), [2, 2], ms.int32) - assert tensor.shape == (2, 2) + assert tensor.shape == [2, 2] _check_value(tensor.to_tensor(), 2, 2) @@ -119,22 +119,22 @@ def test_init_uniform_alias(): def test_init_normal(): tensor = init.initializer(init.Normal(), [5, 4], ms.float32) - assert isinstance(tensor, init.Normal), 'Normal init failed!' + assert isinstance(tensor, MetaTensor), 'Normal init failed!' def test_init_truncated_normal(): tensor = init.initializer(init.TruncatedNormal(), [5, 4], ms.float32) - assert isinstance(tensor, init.TruncatedNormal), 'TruncatedNormal init failed!' + assert isinstance(tensor, MetaTensor), 'TruncatedNormal init failed!' def test_init_normal_alias(): tensor = init.initializer('normal', [5, 4], ms.float32) - assert isinstance(tensor, init.Normal), 'Normal init failed!' + assert isinstance(tensor, MetaTensor), 'Normal init failed!' def test_init_truncatednormal_alias(): tensor = init.initializer('truncatednormal', [5, 4], ms.float32) - assert isinstance(tensor, init.TruncatedNormal), 'TruncatedNormal init failed!' + assert isinstance(tensor, MetaTensor), 'TruncatedNormal init failed!' 
def test_init_abnormal(): @@ -144,15 +144,7 @@ def test_init_abnormal(): def test_initializer_reinit(): weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16) - assert weights.dtype == ms.float16 - assert weights.shape == (10, 1, 10, 10) - weights = init.initializer(weights) - assert weights.dtype == ms.float16 - assert weights.shape == (10, 1, 10, 10) - weights.shape = None - weights = init.initializer(weights, (10, 1)) - assert weights.dtype == ms.float16 - assert weights.shape == (10, 1) + assert isinstance(weights, MetaTensor), 'XavierUniform init failed!' def test_init_xavier_uniform():
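
Note for reviewers: the new MetaTensor.to_tensor() in tensor.py keeps the old Initializer.to_tensor() seeding contract. A minimal standalone sketch (NumPy only, not part of the patch; to_tensor_sketch and fill_normal are illustrative names, not MindSpore APIs) of that behaviour: when no global seed is set and a slice_index is given, the slice index is used as a temporary NumPy seed so every device initializing the same parameter slice produces identical data, and the previous seed word is restored afterwards.

import numpy as np

def to_tensor_sketch(init_fn, shape, slice_index=None, global_seed=None):
    # Uninitialized buffer that the initializer fills in place, like Initializer.__call__(arr).
    arr = np.ndarray(shape, dtype=np.float32)
    need_set_seed = (slice_index is not None) and (global_seed is None)
    seed_saved = np.random.get_state()[1][0]  # remember the current seed word
    if need_set_seed:
        np.random.seed(slice_index)           # deterministic per-slice seed
    init_fn(arr)
    if need_set_seed:
        np.random.seed(seed_saved)            # restore the previous seeding
    return arr

def fill_normal(arr):
    arr[...] = np.random.randn(*arr.shape)    # stand-in for a Normal initializer

# Two "devices" asking for the same slice index get identical values.
a = to_tensor_sketch(fill_normal, (2, 2), slice_index=3)
b = to_tensor_sketch(fill_normal, (2, 2), slice_index=3)
assert np.array_equal(a, b)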
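
A hedged usage sketch of the reworked front end, assuming the patched modules above are importable. The flow (string alias -> MetaTensor -> to_tensor() / Parameter.init_data()) is taken directly from this diff; the variable names are illustrative.

import mindspore as ms
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter

# initializer() now returns a MetaTensor: shape and dtype only, no data buffer yet.
meta = initializer('XavierUniform', shape=(10, 1, 10, 10), dtype=ms.float16)

# Materialize it explicitly ...
data = meta.to_tensor()

# ... or hand it to a Parameter. In "semi_auto_parallel"/"auto_parallel" mode the
# Parameter keeps only the meta info until init_data() is called; otherwise it is
# converted to a Tensor right away and init_data() simply returns it.
w = Parameter(meta, name='w')
w = w.init_data()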
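
The test changes above also show the intended downstream style: layers now take initializer aliases instead of pre-built weight tensors. A short example using the same keyword arguments as the updated test network (layer sizes here are illustrative only):

import mindspore.nn as nn

# Weight/bias initializers are passed by alias; the layer resolves them through
# initializer(), so no explicitly shaped weight tensor is constructed by hand.
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0,
                 weight_init='XavierUniform', has_bias=False, pad_mode="same")
fc = nn.Dense(2048, 10, weight_init='XavierUniform', bias_init='Uniform')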