|
|
|
@ -14,16 +14,17 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Parameters utils"""
|
|
|
|
|
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
|
from mindspore.common.initializer import initializer, TruncatedNormal
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
def init_net_param(network, initialize_mode='TruncatedNormal'):
|
|
|
|
|
"""Init the parameters in net."""
|
|
|
|
|
params = network.trainable_params()
|
|
|
|
|
for p in params:
|
|
|
|
|
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
|
|
|
|
if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
|
|
|
|
np.random.seed(seed=1)
|
|
|
|
|
if initialize_mode == 'TruncatedNormal':
|
|
|
|
|
p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape, p.data.dtype))
|
|
|
|
|
p.set_parameter_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
|
|
|
|
|
else:
|
|
|
|
|
p.set_parameter_data(initialize_mode, p.data.shape, p.data.dtype)
|
|
|
|
|
|
|
|
|
|