|
|
|
@ -41,7 +41,6 @@ class Initializer:
|
|
|
|
|
self._kwargs = kwargs
|
|
|
|
|
self.shape = None
|
|
|
|
|
self.dtype = None
|
|
|
|
|
self._seed = None
|
|
|
|
|
|
|
|
|
|
def _initialize(self, *kwargs):
|
|
|
|
|
raise NotImplementedError('Must be overridden!')
|
|
|
|
@ -49,15 +48,6 @@ class Initializer:
|
|
|
|
|
def __call__(self, arr):
|
|
|
|
|
return self._initialize(arr)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def seed(self):
|
|
|
|
|
return self._seed
|
|
|
|
|
|
|
|
|
|
@seed.setter
|
|
|
|
|
def seed(self, seed_):
|
|
|
|
|
"""set the random seed."""
|
|
|
|
|
self._seed = seed_
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def shape(self):
|
|
|
|
|
return self._shape
|
|
|
|
@ -74,8 +64,15 @@ class Initializer:
|
|
|
|
|
def dtype(self, dtype):
|
|
|
|
|
self._dtype = dtype
|
|
|
|
|
|
|
|
|
|
def to_tensor(self):
|
|
|
|
|
"""Get the tensor format data of this Initializer."""
|
|
|
|
|
def to_tensor(self, slice_index=None):
|
|
|
|
|
"""
|
|
|
|
|
Get the tensor format data of this Initializer.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
slice_index (int): Slice index of a parameter's slices.
|
|
|
|
|
Used when initialize a slice of a parameter, it guarantee that
|
|
|
|
|
devices use the same slice can generate the same tensor.
|
|
|
|
|
"""
|
|
|
|
|
arr = None
|
|
|
|
|
try:
|
|
|
|
|
arr = np.ndarray(self.shape)
|
|
|
|
@ -83,10 +80,10 @@ class Initializer:
|
|
|
|
|
msg = "Error shape={}".format(self.shape)
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
if self._seed is not None:
|
|
|
|
|
np.random.seed(self.seed)
|
|
|
|
|
|
|
|
|
|
if slice_index is not None:
|
|
|
|
|
np.random.seed(slice_index)
|
|
|
|
|
self.__call__(arr)
|
|
|
|
|
self._seed = None
|
|
|
|
|
return Tensor(arr, dtype=self.dtype)
|
|
|
|
|
|
|
|
|
|
def _register(*aliases):
|
|
|
|
|