avx_docs
commit
48031dd443
@ -0,0 +1,60 @@
|
||||
import unittest
|
||||
import sys
|
||||
|
||||
try:
|
||||
import py_paddle
|
||||
|
||||
del py_paddle
|
||||
except ImportError:
|
||||
print >> sys.stderr, "It seems swig of Paddle is not installed, this " \
|
||||
"unittest will not be run."
|
||||
sys.exit(0)
|
||||
|
||||
import paddle.v2.parameters as parameters
|
||||
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
|
||||
import random
|
||||
import cStringIO
|
||||
import numpy
|
||||
|
||||
|
||||
def __rand_param_config__(name):
|
||||
conf = ParameterConfig()
|
||||
conf.name = name
|
||||
size = 1
|
||||
for i in xrange(2):
|
||||
dim = random.randint(1, 1000)
|
||||
conf.dims.append(dim)
|
||||
size *= dim
|
||||
conf.size = size
|
||||
assert conf.IsInitialized()
|
||||
return conf
|
||||
|
||||
|
||||
class TestParameters(unittest.TestCase):
|
||||
def test_serialization(self):
|
||||
params = parameters.Parameters()
|
||||
params.__append_config__(__rand_param_config__("param_0"))
|
||||
params.__append_config__(__rand_param_config__("param_1"))
|
||||
|
||||
for name in params.names():
|
||||
param = params.get(name)
|
||||
param[:] = numpy.random.uniform(
|
||||
-1.0, 1.0, size=params.get_shape(name))
|
||||
params.set(name, param)
|
||||
|
||||
tmp_file = cStringIO.StringIO()
|
||||
params.to_tar(tmp_file)
|
||||
tmp_file.seek(0)
|
||||
params_dup = parameters.Parameters.from_tar(tmp_file)
|
||||
|
||||
self.assertEqual(params_dup.names(), params.names())
|
||||
|
||||
for name in params.names():
|
||||
self.assertEqual(params.get_shape(name), params_dup.get_shape(name))
|
||||
p0 = params.get(name)
|
||||
p1 = params_dup.get(name)
|
||||
self.assertTrue(numpy.isclose(p0, p1).all())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue