|
|
|
@ -51,7 +51,7 @@ class Parameters(object):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.__param_conf__ = dict()
|
|
|
|
|
self.__gradient_machines__ = []
|
|
|
|
|
self.__tmp_params__ = []
|
|
|
|
|
self.__tmp_params__ = dict()
|
|
|
|
|
|
|
|
|
|
def __append_config__(self, param_conf):
|
|
|
|
|
"""
|
|
|
|
@ -128,13 +128,10 @@ class Parameters(object):
|
|
|
|
|
|
|
|
|
|
if len(self.__gradient_machines__) == 0:
|
|
|
|
|
# create new parameter in python numpy.
|
|
|
|
|
if len(self.__tmp_params__) != 0:
|
|
|
|
|
ret_list = [
|
|
|
|
|
mat for name, mat in self.__tmp_params__ if name == key
|
|
|
|
|
]
|
|
|
|
|
if len(ret_list) == 1:
|
|
|
|
|
return ret_list[0]
|
|
|
|
|
return np.ndarray(shape=shape, dtype=np.float32)
|
|
|
|
|
if key in self.__tmp_params__:
|
|
|
|
|
return self.__tmp_params__[key]
|
|
|
|
|
else:
|
|
|
|
|
return np.ndarray(shape=shape, dtype=np.float32)
|
|
|
|
|
else:
|
|
|
|
|
for each_gradient_machine in self.__gradient_machines__:
|
|
|
|
|
param = __get_parameter_in_gradient_machine__(
|
|
|
|
@ -187,7 +184,7 @@ class Parameters(object):
|
|
|
|
|
(shape, value.shape))
|
|
|
|
|
|
|
|
|
|
if len(self.__gradient_machines__) == 0:
|
|
|
|
|
self.__tmp_params__.append((key, value))
|
|
|
|
|
self.__tmp_params__[key] = value
|
|
|
|
|
else:
|
|
|
|
|
for each_gradient_machine in self.__gradient_machines__:
|
|
|
|
|
__copy_parameter_to_gradient_machine__(each_gradient_machine,
|
|
|
|
@ -231,7 +228,7 @@ class Parameters(object):
|
|
|
|
|
raise ValueError("gradient_machine should be api.GradientMachine")
|
|
|
|
|
|
|
|
|
|
if len(self.__tmp_params__) != 0:
|
|
|
|
|
for name, val in self.__tmp_params__:
|
|
|
|
|
for name, val in self.__tmp_params__.iteritems():
|
|
|
|
|
try:
|
|
|
|
|
__copy_parameter_to_gradient_machine__(gradient_machine,
|
|
|
|
|
name, val)
|
|
|
|
@ -302,6 +299,12 @@ class Parameters(object):
|
|
|
|
|
params.deserialize(param_name, f)
|
|
|
|
|
return params
|
|
|
|
|
|
|
|
|
|
def init_from_tar(self, f):
|
|
|
|
|
tar_param = self.from_tar(f)
|
|
|
|
|
for pname in tar_param.names():
|
|
|
|
|
if pname in self.names():
|
|
|
|
|
self.set(pname, tar_param.get(pname))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __get_parameter_in_gradient_machine__(gradient_machine, name):
|
|
|
|
|
"""
|
|
|
|
|