|
|
|
@ -224,19 +224,6 @@ class Parameters(object):
|
|
|
|
|
|
|
|
|
|
self.__gradient_machines__.append(gradient_machine)
|
|
|
|
|
|
|
|
|
|
def __getstate__(self):
|
|
|
|
|
params = {}
|
|
|
|
|
for name in self.names():
|
|
|
|
|
params[name] = self.get(name)
|
|
|
|
|
|
|
|
|
|
param_conf = {}
|
|
|
|
|
for name in self.__param_conf__:
|
|
|
|
|
conf = self.__param_conf__[name]
|
|
|
|
|
assert isinstance(conf, ParameterConfig)
|
|
|
|
|
param_conf[name] = conf.SerializeToString()
|
|
|
|
|
|
|
|
|
|
return {'conf': param_conf, 'params': params}
|
|
|
|
|
|
|
|
|
|
def serialize(self, name, f):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -260,10 +247,10 @@ class Parameters(object):
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
f.read(16) # header
|
|
|
|
|
arr = np.fromfile(f, dtype=np.float32)
|
|
|
|
|
arr = np.frombuffer(f.read(), dtype=np.float32)
|
|
|
|
|
self.set(name, arr.reshape(self.get_shape(name)))
|
|
|
|
|
|
|
|
|
|
def serialize_to_tar(self, f):
|
|
|
|
|
def to_tar(self, f):
|
|
|
|
|
tar = tarfile.TarFile(fileobj=f, mode='w')
|
|
|
|
|
for nm in self.names():
|
|
|
|
|
buf = cStringIO.StringIO()
|
|
|
|
@ -273,19 +260,30 @@ class Parameters(object):
|
|
|
|
|
tarinfo.size = len(buf.getvalue())
|
|
|
|
|
tar.addfile(tarinfo, buf)
|
|
|
|
|
|
|
|
|
|
def __setstate__(self, obj):
|
|
|
|
|
Parameters.__init__(self)
|
|
|
|
|
|
|
|
|
|
def __impl__(conf, params):
|
|
|
|
|
for name in conf:
|
|
|
|
|
p = ParameterConfig()
|
|
|
|
|
p.ParseFromString(conf[name])
|
|
|
|
|
self.__append_config__(p)
|
|
|
|
|
for name in params:
|
|
|
|
|
shape = self.get_shape(name)
|
|
|
|
|
self.set(name, params[name].reshape(shape))
|
|
|
|
|
|
|
|
|
|
__impl__(**obj)
|
|
|
|
|
conf = self.__param_conf__[nm]
|
|
|
|
|
confStr = conf.SerializeToString()
|
|
|
|
|
tarinfo = tarfile.TarInfo(name="%s.protobuf" % nm)
|
|
|
|
|
tarinfo.size = len(confStr)
|
|
|
|
|
buf = cStringIO.StringIO(confStr)
|
|
|
|
|
buf.seek(0)
|
|
|
|
|
tar.addfile(tarinfo, fileobj=buf)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def from_tar(f):
|
|
|
|
|
params = Parameters()
|
|
|
|
|
tar = tarfile.TarFile(fileobj=f, mode='r')
|
|
|
|
|
for finfo in tar:
|
|
|
|
|
assert isinstance(finfo, tarfile.TarInfo)
|
|
|
|
|
if finfo.name.endswith('.protobuf'):
|
|
|
|
|
f = tar.extractfile(finfo)
|
|
|
|
|
conf = ParameterConfig()
|
|
|
|
|
conf.ParseFromString(f.read())
|
|
|
|
|
params.__append_config__(conf)
|
|
|
|
|
|
|
|
|
|
for param_name in params.names():
|
|
|
|
|
f = tar.extractfile(param_name)
|
|
|
|
|
params.deserialize(param_name, f)
|
|
|
|
|
return params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __get_parameter_in_gradient_machine__(gradient_machine, name):
|
|
|
|
|