Merge pull request #1520 from reyoung/feature/serialize_deserialize_in_parameters
Add save/load parameters.avx_docs
commit
963bd5d5ea
@ -0,0 +1,60 @@
|
|||||||
|
import unittest
|
||||||
|
import sys
|
||||||
|
|
||||||
|
try:
|
||||||
|
import py_paddle
|
||||||
|
|
||||||
|
del py_paddle
|
||||||
|
except ImportError:
|
||||||
|
print >> sys.stderr, "It seems swig of Paddle is not installed, this " \
|
||||||
|
"unittest will not be run."
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
import paddle.v2.parameters as parameters
|
||||||
|
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
|
||||||
|
import random
|
||||||
|
import cStringIO
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
def __rand_param_config__(name):
|
||||||
|
conf = ParameterConfig()
|
||||||
|
conf.name = name
|
||||||
|
size = 1
|
||||||
|
for i in xrange(2):
|
||||||
|
dim = random.randint(1, 1000)
|
||||||
|
conf.dims.append(dim)
|
||||||
|
size *= dim
|
||||||
|
conf.size = size
|
||||||
|
assert conf.IsInitialized()
|
||||||
|
return conf
|
||||||
|
|
||||||
|
|
||||||
|
class TestParameters(unittest.TestCase):
|
||||||
|
def test_serialization(self):
|
||||||
|
params = parameters.Parameters()
|
||||||
|
params.__append_config__(__rand_param_config__("param_0"))
|
||||||
|
params.__append_config__(__rand_param_config__("param_1"))
|
||||||
|
|
||||||
|
for name in params.names():
|
||||||
|
param = params.get(name)
|
||||||
|
param[:] = numpy.random.uniform(
|
||||||
|
-1.0, 1.0, size=params.get_shape(name))
|
||||||
|
params.set(name, param)
|
||||||
|
|
||||||
|
tmp_file = cStringIO.StringIO()
|
||||||
|
params.to_tar(tmp_file)
|
||||||
|
tmp_file.seek(0)
|
||||||
|
params_dup = parameters.Parameters.from_tar(tmp_file)
|
||||||
|
|
||||||
|
self.assertEqual(params_dup.names(), params.names())
|
||||||
|
|
||||||
|
for name in params.names():
|
||||||
|
self.assertEqual(params.get_shape(name), params_dup.get_shape(name))
|
||||||
|
p0 = params.get(name)
|
||||||
|
p1 = params_dup.get(name)
|
||||||
|
self.assertTrue(numpy.isclose(p0, p1).all())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue