delete attribute seed of Initializer

pull/2299/head
Yi Huaijie 5 years ago
parent 60de9089ba
commit eae69a386a

@ -41,7 +41,6 @@ class Initializer:
self._kwargs = kwargs self._kwargs = kwargs
self.shape = None self.shape = None
self.dtype = None self.dtype = None
self._seed = None
def _initialize(self, *kwargs): def _initialize(self, *kwargs):
raise NotImplementedError('Must be overridden!') raise NotImplementedError('Must be overridden!')
@ -49,15 +48,6 @@ class Initializer:
def __call__(self, arr): def __call__(self, arr):
return self._initialize(arr) return self._initialize(arr)
@property
def seed(self):
return self._seed
@seed.setter
def seed(self, seed_):
"""set the random seed."""
self._seed = seed_
@property @property
def shape(self): def shape(self):
return self._shape return self._shape
@ -74,8 +64,15 @@ class Initializer:
def dtype(self, dtype): def dtype(self, dtype):
self._dtype = dtype self._dtype = dtype
def to_tensor(self): def to_tensor(self, slice_index=None):
"""Get the tensor format data of this Initializer.""" """
Get the tensor format data of this Initializer.
Args:
slice_index (int): Slice index of a parameter's slices.
Used when initializing a slice of a parameter; it guarantees that
devices using the same slice generate the same tensor.
"""
arr = None arr = None
try: try:
arr = np.ndarray(self.shape) arr = np.ndarray(self.shape)
@ -83,10 +80,10 @@ class Initializer:
msg = "Error shape={}".format(self.shape) msg = "Error shape={}".format(self.shape)
logger.error(msg) logger.error(msg)
raise ValueError(msg) raise ValueError(msg)
if self._seed is not None:
np.random.seed(self.seed) if slice_index is not None:
np.random.seed(slice_index)
self.__call__(arr) self.__call__(arr)
self._seed = None
return Tensor(arr, dtype=self.dtype) return Tensor(arr, dtype=self.dtype)
def _register(*aliases): def _register(*aliases):

@ -22,7 +22,7 @@ from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular from .._checkparam import _check_str_by_regular
from ..parallel._utils import _set_clone_info, _CloneInfo from ..parallel._utils import _set_clone_info, _CloneInfo
from ..parallel._tensor import _get_seed from ..parallel._tensor import _get_slice_index
__all__ = ['Parameter', 'ParameterTuple'] __all__ = ['Parameter', 'ParameterTuple']
@ -250,9 +250,11 @@ class Parameter:
raise ValueError("The length of layout must be 3! layout is {}." raise ValueError("The length of layout must be 3! layout is {}."
.format(layout)) .format(layout))
self.init_mode.shape = layout[2] self.init_mode.shape = layout[2]
self.init_mode.seed = int(_get_seed(layout[0], layout[1])) slice_index = int(_get_slice_index(layout[0], layout[1]))
self.default_input = self.init_mode.to_tensor(slice_index)
else:
self.default_input = self.init_mode.to_tensor()
self.default_input = self.init_mode.to_tensor()
self.init_mode = None self.init_mode = None
if set_sliced: if set_sliced:
self.sliced = True self.sliced = True

@ -168,21 +168,21 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
raise ValueError("The length of np_tensor does not match the length of strategy!") raise ValueError("The length of np_tensor does not match the length of strategy!")
return _chunk_tensor(np_tensor, strategy, len(strategy)) return _chunk_tensor(np_tensor, strategy, len(strategy))
def _get_seed(dev_mat, tensor_map): def _get_slice_index(dev_mat, tensor_map):
""" """
Get the random seed for current slice. Get the slice index for current slice.
Args: Args:
dev_mat (list): The device matrix of devices. dev_mat (list): The device matrix of devices.
tensor_map (list): The split strategy of tensor. tensor_map (list): The split strategy of tensor.
Returns: Returns:
Integer, the local random seed for this device. Integer, the index of the slice held by this device.
""" """
rank = get_rank() rank = get_rank()
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
tensor_slice_seed = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
return tensor_slice_seed return tensor_slice_index
def _load_tensor(tensor, dev_mat, tensor_map): def _load_tensor(tensor, dev_mat, tensor_map):
""" """

Loading…
Cancel
Save