|
|
|
@ -1,7 +1,9 @@
|
|
|
|
|
import numpy as np
|
|
|
|
|
import py_paddle.swig_paddle as api
|
|
|
|
|
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
|
|
|
|
|
|
|
|
|
|
import struct
|
|
|
|
|
import tarfile
|
|
|
|
|
import cStringIO
|
|
|
|
|
from topology import Topology
|
|
|
|
|
|
|
|
|
|
__all__ = ['Parameters', 'create']
|
|
|
|
@ -235,6 +237,42 @@ class Parameters(object):
|
|
|
|
|
|
|
|
|
|
return {'conf': param_conf, 'params': params}
|
|
|
|
|
|
|
|
|
|
def serialize(self, name, f):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
:param name:
|
|
|
|
|
:param f:
|
|
|
|
|
:type f: file
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
param = self.get(name)
|
|
|
|
|
size = reduce(lambda a, b: a * b, param.shape)
|
|
|
|
|
f.write(struct.pack("IIQ", 0, 4, size))
|
|
|
|
|
param = param.astype(np.float32)
|
|
|
|
|
f.write(param.tobytes())
|
|
|
|
|
|
|
|
|
|
def deserialize(self, name, f):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
:param name:
|
|
|
|
|
:param f:
|
|
|
|
|
:type f: file
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
f.read(16) # header
|
|
|
|
|
arr = np.fromfile(f, dtype=np.float32)
|
|
|
|
|
self.set(name, arr.reshape(self.get_shape(name)))
|
|
|
|
|
|
|
|
|
|
def serialize_to_tar(self, f):
|
|
|
|
|
tar = tarfile.TarFile(fileobj=f, mode='w')
|
|
|
|
|
for nm in self.names():
|
|
|
|
|
buf = cStringIO.StringIO()
|
|
|
|
|
self.serialize(nm, buf)
|
|
|
|
|
tarinfo = tarfile.TarInfo(name=nm)
|
|
|
|
|
buf.seek(0)
|
|
|
|
|
tarinfo.size = len(buf.getvalue())
|
|
|
|
|
tar.addfile(tarinfo, buf)
|
|
|
|
|
|
|
|
|
|
def __setstate__(self, obj):
|
|
|
|
|
Parameters.__init__(self)
|
|
|
|
|
|
|
|
|
|