Use MetaTensor instead of Initializer

pull/7373/head
lilei 5 years ago
parent 56b7562bf5
commit 7b135990e7

@@ -15,16 +15,13 @@
 """Initializer for cell parameters."""
 import numbers
 import math
-import copy
 from functools import reduce
 import numpy as np
 from scipy.stats import truncnorm
-from mindspore import log as logger
 from . import dtype as mstype
-from .tensor import Tensor
+from .tensor import Tensor, MetaTensor
-from .seed import get_seed
 from .._c_expression import random_normal
 _INITIALIZER_ALIAS = dict()
@@ -52,54 +49,6 @@ class Initializer:
     def __call__(self, arr):
         return self._initialize(arr)
-    @property
-    def shape(self):
-        return self._shape
-    @shape.setter
-    def shape(self, shape):
-        self._shape = shape
-    @property
-    def dtype(self):
-        return self._dtype
-    @dtype.setter
-    def dtype(self, dtype):
-        self._dtype = dtype
-    def to_tensor(self, slice_index=None, shape=None):
-        """
-        Get the tensor format data of this Initializer.
-        Args:
-            slice_index (int): Slice index of a parameter's slices.
-                It is used when initialize a slice of a parameter, it guarantees that devices
-                using the same slice can generate the same tensor.
-            shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
-        """
-        arr = None
-        if shape is None:
-            shape = self.shape
-        try:
-            arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype))
-        except ValueError:
-            msg = "Error shape={}".format(shape)
-            logger.error(msg)
-            raise ValueError(msg)
-        global_seed = get_seed()
-        need_set_seed = ((slice_index is not None) and (global_seed is None))
-        seed_saved = np.random.get_state()[1][0]
-        if need_set_seed:
-            np.random.seed(slice_index)
-        self.__call__(arr)
-        if need_set_seed:
-            np.random.seed(seed_saved)
-        return Tensor(arr, dtype=self.dtype)
 def _register(*aliases):
     """Return the alias register."""
     def alias_reg(cls):
@@ -478,27 +427,16 @@ def initializer(init, shape=None, dtype=mstype.float32):
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"shape is invalid, shape value must be positive integer, shape:{shape}")
-    if isinstance(init, Initializer):
-        init_copy = copy.deepcopy(init)
-        init_copy.shape = shape if shape is not None else init.shape
-        init_copy.dtype = init.dtype if init.dtype is not None else dtype
-        return init_copy
     if isinstance(init, str):
-        init_obj = _INITIALIZER_ALIAS[init.lower()]()
-        if init_obj is None:
+        init = _INITIALIZER_ALIAS[init.lower()]()
+        if init is None:
             raise ValueError("The class corresponding to '{}' was not found.".format(init))
-        init = init_obj
-        init.shape = shape
-        init.dtype = dtype
-        return init
-    if isinstance(init, numbers.Number):
-        init_obj = Constant(init)
-        init_obj.shape = shape
-        init_obj.dtype = dtype
-        return init_obj
-    raise TypeError("Unsupported init type '{}'.".format(type(init)))
+    elif isinstance(init, numbers.Number):
+        init = Constant(init)
+    shape = shape if shape is not None else init.shape
+    dtype = init.dtype if init.dtype is not None else dtype
+    init_obj = MetaTensor(init, dtype, shape)
+    return init_obj
 __all__ = [
     'Initializer',
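
A minimal sketch of the reworked contract, using only names visible in this hunk (initializer, MetaTensor.to_tensor); the call sites are an assumption for illustration, not part of the commit:

# Sketch: initializer() now returns a lightweight MetaTensor that records the
# init method, dtype and shape; data is materialized only on to_tensor().
import mindspore as ms
from mindspore.common.initializer import initializer

meta = initializer('normal', shape=[5, 4], dtype=ms.float32)  # MetaTensor, no buffer yet
print(meta.shape, meta.dtype)

tensor = meta.to_tensor()       # the numpy buffer is allocated and filled here
print(tensor.shape)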

@@ -16,8 +16,9 @@
 """Parameter for cell."""
 from copy import copy
 from .._c_expression import ParamInfo
+from .._c_expression import MetaTensor as MetaTensor_
 from . import dtype as mstype
-from .initializer import initializer, Initializer
+from .initializer import initializer
 from .tensor import Tensor, MetaTensor
 from .._checkparam import _check_str_by_regular
 from ..parallel._tensor import _get_slice_index
@@ -34,14 +35,14 @@ def _is_in_parallel_mode():
     return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
-class Parameter(MetaTensor):
+class Parameter(MetaTensor_):
     """
     Parameter types of cell models.
     After initialized `Parameter` is a subtype of `Tensor`.
     In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
-    an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor`
+    an `MetaTensor`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor_`
     only saves the shape and type info of a tensor with no memory usage. The shape can be changed while
     compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
@@ -52,7 +53,7 @@ class Parameter(MetaTensor):
     then the Parameters as this part of the inputs are not allowed to be cast.
     Args:
-        default_input (Union[Tensor, Initializer, Number]): Parameter data, to be set initialized.
+        default_input (Union[Tensor, MetaTensor, Number]): Parameter data, to be set initialized.
         name (str): Name of the child parameter.
         requires_grad (bool): True if the parameter requires gradient. Default: True.
         layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
@@ -94,9 +95,9 @@ class Parameter(MetaTensor):
         input_class.__init__(obj, *class_init_args)
         # it's better to make the Initializer a kind of metatensor.
         obj.init_mode = None
-        obj.is_default_input_initializer = False
-        if isinstance(default_input, Initializer):
-            obj.is_default_input_initializer = True
+        obj.is_default_input_meta = False
+        if isinstance(default_input, MetaTensor):
+            obj.is_default_input_meta = True
         if not isinstance(obj, Tensor):
             obj.init_mode = default_input
         return obj
@@ -142,10 +143,10 @@ class Parameter(MetaTensor):
         """Set `set_data` of current `Parameter`."""
         if isinstance(data, bool):
             raise ValueError('Parameter data can not be `bool`')
-        if isinstance(data, Initializer):
+        if isinstance(data, MetaTensor):
             if _is_in_parallel_mode():
                 # do not init data while in auto parallel.
-                return (MetaTensor, data.dtype, data.shape)
+                return (MetaTensor_, data.dtype, data.shape)
             data = data.to_tensor()
         if isinstance(data, Tensor):
             # make a copy of Tensor to init the parameter
@@ -257,7 +258,7 @@ class Parameter(MetaTensor):
         Args:
             prefix (str): Namespace of parameter. The cloned Parameter name is
                 combined of prefix and current name: `f"{perfix}.{self.name}"`.
-            init (Union[Tensor, str, Initializer, numbers.Number]): Initialize the shape of the parameter.
+            init (Union[Tensor, str, MetaTensor, numbers.Number]): Initialize the shape of the parameter.
                 Default: 'same'.
         Returns:
@@ -314,7 +315,7 @@ class Parameter(MetaTensor):
         Set `set_data` of current `Parameter`.
         Args:
-            data (Union[Tensor, Initializer, int, float]): new data.
+            data (Union[Tensor, MetaTensor, int, float]): new data.
             slice_shape (bool): If slice the Parameter, will not check if shape is match. Default: False.
         Retruns:
@@ -325,9 +326,9 @@ class Parameter(MetaTensor):
                                 f"Current dtype is {self.dtype}, and incoming is {incoming}. "
                                 f"Use .set_dtype(xxx) to change the dtype.")
-        if not isinstance(data, (MetaTensor, Initializer, int, float)):
-            raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
-                            f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.")
+        if not isinstance(data, (MetaTensor_, int, float)):
+            raise TypeError(f"Parameter data must be [`MetaTensor`, `int`, `float`] or a kind of `MetaTensor_` "
+                            f"(like `Tensor` or `MetaTensor_`). But with type {type(data)}.")
         if isinstance(data, (int, float)):
             if self.dtype in mstype.int_type and isinstance(data, float):
                 raise_type_error(mstype.float_)
@@ -337,8 +338,8 @@ class Parameter(MetaTensor):
         is_current_tensor = isinstance(self, Tensor)
         if is_incoming_tensor and not is_current_tensor:
-            raise TypeError("Parameter is a `MetaTensor` and not initializered, `data` for `set_data`"
-                            "should be a Initializer. If you want to update it by Tensor, call method"
+            raise TypeError("Parameter is a `MetaTensor_` and not initializered, `data` for `set_data`"
+                            "should be a MetaTensor. If you want to update it by Tensor, call method"
                             "`init_parameters_data` of `Cell` to init and replace all the Parameter of"
                             "network, then call this method.")
         if tuple(self.shape) != tuple(data.shape):
@@ -351,7 +352,7 @@ class Parameter(MetaTensor):
                 raise_type_error(data.dtype)
             else:
                 data = Tensor(data, self.dtype)
-        if isinstance(data, Initializer):
+        if isinstance(data, MetaTensor):
             # The parameter has been initializered, directly update by the data
             if is_current_tensor:
                 self._update_tensor_data(data.to_tensor())
@@ -387,10 +388,10 @@ class Parameter(MetaTensor):
            Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
            returns the same initialized `Parameter`.
        """
-        if self.is_default_input_initializer:
+        if self.is_default_input_meta:
            is_current_in_parallel = _is_in_parallel_mode()
            if self.is_in_parallel != is_current_in_parallel:
-               raise RuntimeError("Must set or change parallel mode before any Initializer created.")
+               raise RuntimeError("Must set or change parallel mode before any MetaTensor created.")
        if self.init_mode is None:
            return self
        if self.inited_param is not None:
@@ -401,12 +402,12 @@ class Parameter(MetaTensor):
            if len(layout) < 3:
                raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
            slice_index = int(_get_slice_index(layout[0], layout[1]))
-           if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)):
+           if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
                data = self.init_mode.to_tensor(0, [1])
            else:
                data = self.init_mode.to_tensor(slice_index, layout[2])
        else:
-           if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)):
+           if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
                data = self.init_mode.to_tensor(0, [1])
            else:
                data = self.init_mode.to_tensor()
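
For context, a hedged sketch (assumed usage, not taken from this diff) of how a Parameter built from a MetaTensor stays unmaterialized until init_data runs:

# Sketch: Parameter keeps the MetaTensor as init_mode; init_data() turns it
# into a Tensor-backed Parameter (materialization is deferred in auto-parallel
# mode; in normal mode the data may already be initialized at construction).
import mindspore as ms
from mindspore import Parameter
from mindspore.common.initializer import initializer

w = Parameter(initializer('XavierUniform', [10, 5], ms.float32), name='w')
w = w.init_data()   # returns the initialized Parameter
print(w.shape)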

@@ -15,8 +15,9 @@
 """Tensor implementation."""
 import numpy as np
+from mindspore import log as logger
 from .._c_expression import Tensor as Tensor_
-from .._c_expression import MetaTensor
+from .._c_expression import MetaTensor as MetaTensor_
 from .._checkparam import check_type, check_typename
 from . import dtype as mstype
 from ._register_for_tensor import tensor_operator_registry
@@ -395,6 +396,50 @@ class SparseTensor:
         return self.__dense_shape
+
+class MetaTensor(MetaTensor_):
+    """
+    The base class of the MetaTensor.
+    Initialization of tensor basic attributes and model weight values.
+
+    Returns:
+        Array, an array after being initialized.
+    """
+    def __init__(self, init, dtype, shape):
+        #check param
+        self.init = init
+        MetaTensor_.__init__(self, dtype, shape)
+
+    def to_tensor(self, slice_index=None, shape=None):
+        """
+        Get the tensor format data of this MetaTensor.
+
+        Args:
+            slice_index (int): Slice index of a parameter's slices.
+                It is used when initialize a slice of a parameter, it guarantees that devices
+                using the same slice can generate the same tensor.
+            shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
+        """
+        if shape is None:
+            shape = self.shape
+        try:
+            arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype))
+        except ValueError:
+            msg = "Error shape={}".format(shape)
+            logger.error(msg)
+            raise ValueError(msg)
+
+        from .seed import get_seed
+        global_seed = get_seed()
+        need_set_seed = ((slice_index is not None) and (global_seed is None))
+        seed_saved = np.random.get_state()[1][0]
+        if need_set_seed:
+            np.random.seed(slice_index)
+        self.init(arr)
+        if need_set_seed:
+            np.random.seed(seed_saved)
+        return Tensor(arr, dtype=self.dtype)
+
 def _vm_compare(*args):
     """Implement `vm_compare` for tensor."""
     obj_str = args[-1]
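
The seeding dance in to_tensor above can be illustrated with plain NumPy; this is an explanatory sketch of the idea (it restores the RNG via set_state rather than re-seeding, so it is not a line-for-line copy):

# Sketch: when no global seed is set, seeding with the slice index makes
# every device that owns the same parameter slice draw identical values.
import numpy as np

def init_slice(slice_index, shape):
    saved = np.random.get_state()                 # remember current RNG state
    np.random.seed(slice_index)                   # deterministic per slice
    arr = np.random.randn(*shape).astype(np.float32)
    np.random.set_state(saved)                    # restore afterwards
    return arr

# Two "devices" asking for slice 3 get the same numbers.
assert np.array_equal(init_slice(3, (2, 2)), init_slice(3, (2, 2)))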

@@ -32,7 +32,6 @@ from mindspore.nn import Cell
 from mindspore.ops import operations as P
 from mindspore.ops import composite as CP
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.common.initializer import initializer
 from mindspore.nn.wrap.cell_wrapper import WithLossCell
 random.seed(1)
@@ -43,14 +42,6 @@ ds.config.set_seed(1)
 grad_by_list = CP.GradOperation(get_by_list=True)
-def weight_variable(shape):
-    return initializer('XavierUniform', shape=shape, dtype=mstype.float32)
-
-def weight_variable_uniform(shape):
-    return initializer('Uniform', shape=shape, dtype=mstype.float32)
 def weight_variable_0(shape):
     zeros = np.zeros(shape).astype(np.float32)
     return Tensor(zeros)
@@ -63,26 +54,23 @@ def weight_variable_1(shape):
 def conv3x3(in_channels, out_channels, stride=1, padding=0):
     """3x3 convolution """
-    weight_shape = (out_channels, in_channels, 3, 3)
-    weight = weight_variable(weight_shape)
     return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
+                     kernel_size=3, stride=stride, padding=padding, weight_init='XavierUniform',
+                     has_bias=False, pad_mode="same")
 def conv1x1(in_channels, out_channels, stride=1, padding=0):
     """1x1 convolution"""
-    weight_shape = (out_channels, in_channels, 1, 1)
-    weight = weight_variable(weight_shape)
     return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
+                     kernel_size=1, stride=stride, padding=padding, weight_init='XavierUniform',
+                     has_bias=False, pad_mode="same")
 def conv7x7(in_channels, out_channels, stride=1, padding=0):
     """1x1 convolution"""
-    weight_shape = (out_channels, in_channels, 7, 7)
-    weight = weight_variable(weight_shape)
     return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
+                     kernel_size=7, stride=stride, padding=padding, weight_init='XavierUniform',
+                     has_bias=False, pad_mode="same")
 def bn_with_initialize(out_channels):
@@ -90,8 +78,7 @@ def bn_with_initialize(out_channels):
     mean = weight_variable_0(shape)
     var = weight_variable_1(shape)
     beta = weight_variable_0(shape)
-    gamma = weight_variable_uniform(shape)
-    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
+    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform',
                         beta_init=beta, moving_mean_init=mean, moving_var_init=var)
     return bn
@@ -101,18 +88,13 @@ def bn_with_initialize_last(out_channels):
     mean = weight_variable_0(shape)
     var = weight_variable_1(shape)
     beta = weight_variable_0(shape)
-    gamma = weight_variable_uniform(shape)
-    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
+    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform',
                         beta_init=beta, moving_mean_init=mean, moving_var_init=var)
     return bn
 def fc_with_initialize(input_channels, out_channels):
-    weight_shape = (out_channels, input_channels)
-    weight = weight_variable(weight_shape)
-    bias_shape = (out_channels)
-    bias = weight_variable_uniform(bias_shape)
-    return nn.Dense(input_channels, out_channels, weight, bias)
+    return nn.Dense(input_channels, out_channels, weight_init='XavierUniform', bias_init='Uniform')
 class ResidualBlock(nn.Cell):
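
The test changes above swap the removed weight_variable* helpers for string initializers; a hedged sketch of the equivalence follows (whether a layer also accepts the MetaTensor form directly after this change is an assumption made for illustration):

# Sketch: passing the alias string lets the layer build the initializer
# itself, which is what the removed weight_variable() helper used to do.
import mindspore as ms
import mindspore.nn as nn
from mindspore.common.initializer import initializer

conv_str = nn.Conv2d(3, 16, kernel_size=3, weight_init='XavierUniform',
                     has_bias=False, pad_mode="same")

# The long way round: build the MetaTensor explicitly (assumed to be
# accepted by weight_init after this commit).
w = initializer('XavierUniform', shape=[16, 3, 3, 3], dtype=ms.float32)
conv_meta = nn.Conv2d(3, 16, kernel_size=3, weight_init=w,
                      has_bias=False, pad_mode="same")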

@@ -20,7 +20,7 @@ import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import context
-from mindspore.common import MetaTensor
+from mindspore._c_expression import MetaTensor as MetaTensor_
 from mindspore.common import dtype
 from mindspore.common.api import ms_function
 from mindspore.ops import functional as F
@@ -70,8 +70,8 @@ def scalar_mul_while(x):
     return rv
-@ms_function(input_signature=(MetaTensor(dtype.float32, (1, 1, 3, 3)),
-                              MetaTensor(dtype.float32, (1, 1, 3, 3))))
+@ms_function(input_signature=(MetaTensor_(dtype.float32, (1, 1, 3, 3)),
+                              MetaTensor_(dtype.float32, (1, 1, 3, 3))))
 def tensor_add_test(x, y):
     """ tensor_add_test """
     z = F.tensor_add(x, y)

@@ -24,7 +24,7 @@ import mindspore.common.initializer as init
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common.parameter import Parameter
-from mindspore.common.tensor import Tensor
+from mindspore.common.tensor import Tensor, MetaTensor
 from mindspore.nn import Conv2d
 from mindspore.ops import operations as P
 from ..ut_filter import non_graph_engine
@@ -58,7 +58,7 @@ def _check_uniform(tensor, boundary_a, boundary_b):
 def test_init_Initializer():
     tensor = init.initializer(InitTwo(), [2, 2], ms.int32)
-    assert tensor.shape == (2, 2)
+    assert tensor.shape == [2, 2]
     _check_value(tensor.to_tensor(), 2, 2)
@@ -119,22 +119,22 @@ def test_init_uniform_alias():
 def test_init_normal():
     tensor = init.initializer(init.Normal(), [5, 4], ms.float32)
-    assert isinstance(tensor, init.Normal), 'Normal init failed!'
+    assert isinstance(tensor, MetaTensor), 'Normal init failed!'
 def test_init_truncated_normal():
     tensor = init.initializer(init.TruncatedNormal(), [5, 4], ms.float32)
-    assert isinstance(tensor, init.TruncatedNormal), 'TruncatedNormal init failed!'
+    assert isinstance(tensor, MetaTensor), 'TruncatedNormal init failed!'
 def test_init_normal_alias():
     tensor = init.initializer('normal', [5, 4], ms.float32)
-    assert isinstance(tensor, init.Normal), 'Normal init failed!'
+    assert isinstance(tensor, MetaTensor), 'Normal init failed!'
 def test_init_truncatednormal_alias():
     tensor = init.initializer('truncatednormal', [5, 4], ms.float32)
-    assert isinstance(tensor, init.TruncatedNormal), 'TruncatedNormal init failed!'
+    assert isinstance(tensor, MetaTensor), 'TruncatedNormal init failed!'
 def test_init_abnormal():
@@ -144,15 +144,7 @@ def test_init_abnormal():
 def test_initializer_reinit():
     weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16)
-    assert weights.dtype == ms.float16
-    assert weights.shape == (10, 1, 10, 10)
-    weights = init.initializer(weights)
-    assert weights.dtype == ms.float16
-    assert weights.shape == (10, 1, 10, 10)
-    weights.shape = None
-    weights = init.initializer(weights, (10, 1))
-    assert weights.dtype == ms.float16
-    assert weights.shape == (10, 1)
+    assert isinstance(weights, MetaTensor), 'XavierUniform init failed!'
 def test_init_xavier_uniform():
