|
|
@ -19,7 +19,7 @@ import math
|
|
|
|
from functools import reduce
|
|
|
|
from functools import reduce
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
from scipy.stats import truncnorm
|
|
|
|
from scipy.stats import truncnorm
|
|
|
|
|
|
|
|
from .seed import _get_graph_seed
|
|
|
|
from . import dtype as mstype
|
|
|
|
from . import dtype as mstype
|
|
|
|
from .tensor import Tensor, MetaTensor
|
|
|
|
from .tensor import Tensor, MetaTensor
|
|
|
|
from .._c_expression import random_normal
|
|
|
|
from .._c_expression import random_normal
|
|
|
@ -40,8 +40,19 @@ class Initializer:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
def __init__(self, **kwargs):
|
|
|
|
def __init__(self, **kwargs):
|
|
|
|
self._kwargs = kwargs
|
|
|
|
self._kwargs = kwargs
|
|
|
|
self.shape = None
|
|
|
|
self._seed = None
|
|
|
|
self.dtype = None
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
|
|
def seed(self):
|
|
|
|
|
|
|
|
seed_ = self._seed if self._seed is not None else 1
|
|
|
|
|
|
|
|
_, seed = _get_graph_seed(seed_, "init")
|
|
|
|
|
|
|
|
return seed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@seed.setter
|
|
|
|
|
|
|
|
def seed(self, value):
|
|
|
|
|
|
|
|
if not isinstance(value, int):
|
|
|
|
|
|
|
|
raise TypeError("'value' must be int type.")
|
|
|
|
|
|
|
|
self._seed = value
|
|
|
|
|
|
|
|
|
|
|
|
def _initialize(self, *kwargs):
|
|
|
|
def _initialize(self, *kwargs):
|
|
|
|
raise NotImplementedError('Must be overridden!')
|
|
|
|
raise NotImplementedError('Must be overridden!')
|
|
|
@ -353,7 +364,7 @@ class Normal(Initializer):
|
|
|
|
self.sigma = sigma
|
|
|
|
self.sigma = sigma
|
|
|
|
|
|
|
|
|
|
|
|
def _initialize(self, arr):
|
|
|
|
def _initialize(self, arr):
|
|
|
|
seed = np.random.get_state()[1][0]
|
|
|
|
seed = self.seed
|
|
|
|
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
|
|
|
|
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
|
|
|
|
random_normal(0, self.sigma, arr.shape, seed, output_tensor)
|
|
|
|
random_normal(0, self.sigma, arr.shape, seed, output_tensor)
|
|
|
|
output_data = output_tensor.asnumpy()
|
|
|
|
output_data = output_tensor.asnumpy()
|
|
|
@ -434,8 +445,7 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|
|
|
elif isinstance(init, numbers.Number):
|
|
|
|
elif isinstance(init, numbers.Number):
|
|
|
|
init = Constant(init)
|
|
|
|
init = Constant(init)
|
|
|
|
shape = shape if shape is not None else init.shape
|
|
|
|
shape = shape if shape is not None else init.shape
|
|
|
|
dtype = init.dtype if init.dtype is not None else dtype
|
|
|
|
init_obj = MetaTensor(dtype, shape, init)
|
|
|
|
init_obj = MetaTensor(init, dtype, shape)
|
|
|
|
|
|
|
|
return init_obj
|
|
|
|
return init_obj
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
__all__ = [
|
|
|
|