|
|
|
@ -64,7 +64,7 @@ class Initializer:
|
|
|
|
|
def dtype(self, dtype):
|
|
|
|
|
self._dtype = dtype
|
|
|
|
|
|
|
|
|
|
def to_tensor(self, slice_index=None):
|
|
|
|
|
def to_tensor(self, slice_index=None, shape=None):
|
|
|
|
|
"""
|
|
|
|
|
Get the tensor format data of this Initializer.
|
|
|
|
|
|
|
|
|
@ -72,12 +72,16 @@ class Initializer:
|
|
|
|
|
slice_index (int): Slice index of a parameter's slices.
|
|
|
|
|
Used when initialize a slice of a parameter, it guarantee that
|
|
|
|
|
devices use the same slice can generate the same tensor.
|
|
|
|
|
shape (list[int]): Shape of the slice, used when initialize a slice of the parameter.
|
|
|
|
|
"""
|
|
|
|
|
arr = None
|
|
|
|
|
if shape is None:
|
|
|
|
|
shape = self.shape
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
arr = np.ndarray(self.shape)
|
|
|
|
|
arr = np.ndarray(shape)
|
|
|
|
|
except ValueError:
|
|
|
|
|
msg = "Error shape={}".format(self.shape)
|
|
|
|
|
msg = "Error shape={}".format(shape)
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
|
|
|
|
|