|
|
@ -338,12 +338,6 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|
|
|
"the variable shape {}.".format(list(init.shape()), shape))
|
|
|
|
"the variable shape {}.".format(list(init.shape()), shape))
|
|
|
|
return init
|
|
|
|
return init
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(init, str):
|
|
|
|
|
|
|
|
init_obj = _INITIALIZER_ALIAS[init.lower()]()
|
|
|
|
|
|
|
|
if init_obj is None:
|
|
|
|
|
|
|
|
raise ValueError("The class corresponding to '{}' was not found.".format(init))
|
|
|
|
|
|
|
|
init = init_obj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(shape, list):
|
|
|
|
if isinstance(shape, list):
|
|
|
|
shape = tuple(shape)
|
|
|
|
shape = tuple(shape)
|
|
|
|
elif isinstance(shape, numbers.Number):
|
|
|
|
elif isinstance(shape, numbers.Number):
|
|
|
@ -354,6 +348,15 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|
|
|
raise ValueError("Error shape={}".format(shape))
|
|
|
|
raise ValueError("Error shape={}".format(shape))
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(init, Initializer):
|
|
|
|
if isinstance(init, Initializer):
|
|
|
|
|
|
|
|
init.shape = init.shape if init.shape is not None else shape
|
|
|
|
|
|
|
|
init.dtype = init.dtype if init.dtype is not None else dtype
|
|
|
|
|
|
|
|
return init
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(init, str):
|
|
|
|
|
|
|
|
init_obj = _INITIALIZER_ALIAS[init.lower()]()
|
|
|
|
|
|
|
|
if init_obj is None:
|
|
|
|
|
|
|
|
raise ValueError("The class corresponding to '{}' was not found.".format(init))
|
|
|
|
|
|
|
|
init = init_obj
|
|
|
|
init.shape = shape
|
|
|
|
init.shape = shape
|
|
|
|
init.dtype = dtype
|
|
|
|
init.dtype = dtype
|
|
|
|
return init
|
|
|
|
return init
|
|
|
|