|
|
|
@ -51,9 +51,9 @@ def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps
|
|
|
|
|
return lr_each_step
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_net_param(net, init_value='ones'):
    """Initialize the parameters in net.

    NOTE(review): this definition only fetches the trainable parameters
    and never applies an initializer, and it is shadowed by the
    ``init_net_param(network, ...)`` definition that follows later in
    this file — it looks like a stale merge leftover. Confirm with the
    author and remove.

    Args:
        net: Model object exposing ``trainable_params()``.
        init_value (str): Initializer spec; unused in this stale
            version (default ``'ones'``).
    """
    params = net.trainable_params()
|
|
|
|
|
def init_net_param(network, init_value='ones'):
    """Initialize the trainable parameters in network.

    Every trainable parameter whose name does not contain ``'beta'``,
    ``'gamma'`` or ``'bias'`` — i.e. the weight tensors, skipping
    BatchNorm scale/shift and bias terms — is overwritten in place with
    values produced by ``initializer(init_value, ...)``.

    Args:
        network: Model object exposing ``trainable_params()``.
        init_value (str): Initializer name forwarded to ``initializer``
            (default ``'ones'``).
    """
    params = network.trainable_params()
    for p in params:
        # Only re-initialize weight tensors; leave BatchNorm beta/gamma
        # and bias parameters at their existing values.
        if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
            # NOTE(review): ``shape()``/``dtype()`` called as methods and
            # ``set_parameter_data`` are legacy MindSpore (<1.0) APIs —
            # confirm against the project's pinned MindSpore version
            # (newer versions use the ``shape``/``dtype`` properties and
            # ``Parameter.set_data``).
            p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
|
|
|
|
|