!7429 modify normal seed

Merge pull request !7429 from lilei/modify_bug
pull/7429/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b2b9016ddf

@ -19,7 +19,7 @@ import math
from functools import reduce from functools import reduce
import numpy as np import numpy as np
from scipy.stats import truncnorm from scipy.stats import truncnorm
from .seed import _get_graph_seed
from . import dtype as mstype from . import dtype as mstype
from .tensor import Tensor, MetaTensor from .tensor import Tensor, MetaTensor
from .._c_expression import random_normal from .._c_expression import random_normal
@ -40,8 +40,19 @@ class Initializer:
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._kwargs = kwargs self._kwargs = kwargs
self.shape = None self._seed = None
self.dtype = None
@property
def seed(self):
seed_ = self._seed if self._seed is not None else 1
_, seed = _get_graph_seed(seed_, "init")
return seed
@seed.setter
def seed(self, value):
if not isinstance(value, int):
raise TypeError("'value' must be int type.")
self._seed = value
def _initialize(self, *kwargs): def _initialize(self, *kwargs):
raise NotImplementedError('Must be overridden!') raise NotImplementedError('Must be overridden!')
@ -353,7 +364,7 @@ class Normal(Initializer):
self.sigma = sigma self.sigma = sigma
def _initialize(self, arr): def _initialize(self, arr):
seed = np.random.get_state()[1][0] seed = self.seed
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32)) output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
random_normal(0, self.sigma, arr.shape, seed, output_tensor) random_normal(0, self.sigma, arr.shape, seed, output_tensor)
output_data = output_tensor.asnumpy() output_data = output_tensor.asnumpy()
@ -434,8 +445,7 @@ def initializer(init, shape=None, dtype=mstype.float32):
elif isinstance(init, numbers.Number): elif isinstance(init, numbers.Number):
init = Constant(init) init = Constant(init)
shape = shape if shape is not None else init.shape shape = shape if shape is not None else init.shape
dtype = init.dtype if init.dtype is not None else dtype init_obj = MetaTensor(dtype, shape, init)
init_obj = MetaTensor(init, dtype, shape)
return init_obj return init_obj
__all__ = [ __all__ = [

@ -404,7 +404,7 @@ class MetaTensor(MetaTensor_):
Returns: Returns:
Array, an array after being initialized. Array, an array after being initialized.
""" """
def __init__(self, init, dtype, shape): def __init__(self, dtype, shape, init=None):
#check param #check param
self.init = init self.init = init
MetaTensor_.__init__(self, dtype, shape) MetaTensor_.__init__(self, dtype, shape)
@ -419,6 +419,9 @@ class MetaTensor(MetaTensor_):
using the same slice can generate the same tensor. using the same slice can generate the same tensor.
shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter. shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
""" """
if self.init is None:
raise TypeError("to_dense must be set MetaTensor.init, init can't be None")
if shape is None: if shape is None:
shape = self.shape shape = self.shape
@ -428,15 +431,28 @@ class MetaTensor(MetaTensor_):
msg = "Error shape={}".format(shape) msg = "Error shape={}".format(shape)
logger.error(msg) logger.error(msg)
raise ValueError(msg) raise ValueError(msg)
from .seed import get_seed class seed_context:
global_seed = get_seed() '''set and restore seed'''
need_set_seed = ((slice_index is not None) and (global_seed is None)) def __init__(self, init):
seed_saved = np.random.get_state()[1][0] self.init = init
if need_set_seed: from .seed import get_seed
np.random.seed(slice_index) global_seed = get_seed()
self.init(arr) self._np_seed = np.random.get_state()[1][0]
if need_set_seed: self.need_set_seed = ((slice_index is not None) and (global_seed is None))
np.random.seed(seed_saved) self.seed = self.init.seed
def __enter__(self):
if self.need_set_seed:
np.random.seed(slice_index)
self.init.seed = slice_index
def __exit__(self, ptype, value, trace):
if self.need_set_seed:
np.random.seed(self._np_seed)
self.init.seed = self.seed
with seed_context(self.init):
self.init(arr)
return Tensor(arr, dtype=self.dtype) return Tensor(arr, dtype=self.dtype)

@ -24,6 +24,7 @@ from mindspore import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam from mindspore.nn.optim import Adam
from mindspore.common import set_seed
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal from mindspore.common.initializer import TruncatedNormal
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker
@ -95,12 +96,12 @@ def do_sparse_embedding(ps=False):
envs = os.environ envs = os.environ
if __name__ == "__main__": if __name__ == "__main__":
np.random.seed(0) set_seed(0)
ps_loss = do_sparse_embedding(True) ps_loss = do_sparse_embedding(True)
if _is_role_worker(): if _is_role_worker():
context.reset_ps_context() context.reset_ps_context()
np.random.seed(0) set_seed(0)
no_ps_loss = do_sparse_embedding() no_ps_loss = do_sparse_embedding()
context.set_ps_context(enable_ps=True) context.set_ps_context(enable_ps=True)

@ -14,7 +14,7 @@
import numpy as np import numpy as np
from numpy import allclose from numpy import allclose
from mindspore.common import set_seed
import mindspore.common.initializer as init import mindspore.common.initializer as init
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Parameter from mindspore import Parameter
@ -40,10 +40,10 @@ class ParameterNet(nn.Cell):
def test_using_same_seed_for_initializer(): def test_using_same_seed_for_initializer():
np.random.seed(0) set_seed(0)
net1 = ParameterNet() net1 = ParameterNet()
net1.init_parameters_data() net1.init_parameters_data()
np.random.seed(0) set_seed(0)
net2 = ParameterNet() net2 = ParameterNet()
net2.init_parameters_data() net2.init_parameters_data()
for key in net1.parameters_dict(): for key in net1.parameters_dict():

Loading…
Cancel
Save