|
|
|
@ -26,6 +26,10 @@ def create(*topologies):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Parameters(object):
|
|
|
|
|
"""
|
|
|
|
|
The parameters
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.__param_conf__ = dict()
|
|
|
|
|
self.__gradient_machines__ = []
|
|
|
|
@ -66,7 +70,8 @@ class Parameters(object):
|
|
|
|
|
assert isinstance(param, api.Parameter)
|
|
|
|
|
val = param.getBuf(api.PARAMETER_VALUE)
|
|
|
|
|
assert isinstance(val, api.Vector)
|
|
|
|
|
return val.copyToNumpyArray().reshape(shape=shape)
|
|
|
|
|
val = val.copyToNumpyArray()
|
|
|
|
|
return val
|
|
|
|
|
# else continue
|
|
|
|
|
|
|
|
|
|
raise RuntimeError("Unexpected branch")
|
|
|
|
@ -96,6 +101,12 @@ class Parameters(object):
|
|
|
|
|
__copy_parameter_to_gradient_machine__(each_gradient_machine,
|
|
|
|
|
key, value)
|
|
|
|
|
|
|
|
|
|
def get(self, parameter_name):
|
|
|
|
|
return self.__getitem__(key=parameter_name)
|
|
|
|
|
|
|
|
|
|
def set(self, parameter_name, value):
|
|
|
|
|
self.__setitem__(key=parameter_name, value=value)
|
|
|
|
|
|
|
|
|
|
def append_gradient_machine(self, gradient_machine):
|
|
|
|
|
if not isinstance(gradient_machine, api.GradientMachine):
|
|
|
|
|
raise ValueError("gradient_machine should be api.GradientMachine")
|
|
|
|
@ -108,6 +119,7 @@ class Parameters(object):
|
|
|
|
|
except ValueError:
|
|
|
|
|
# If no such parameter in gradient machine, then don't copy
|
|
|
|
|
pass
|
|
|
|
|
self.__gradient_machines__.append(gradient_machine)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __get_parameter_in_gradient_machine__(gradient_machine, name):
|
|
|
|
|