|
|
|
@ -197,7 +197,24 @@ class BaseRNN(object):
|
|
|
|
|
return numpy.array([o.mean() for o in outs.itervalues()]).mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSimpleMul(unittest.TestCase):
|
|
|
|
|
class SeedFixedTestCase(unittest.TestCase):
|
|
|
|
|
@classmethod
|
|
|
|
|
def setUpClass(cls):
|
|
|
|
|
"""Fix random seeds to remove randomness from tests"""
|
|
|
|
|
cls._np_rand_state = numpy.random.get_state()
|
|
|
|
|
cls._py_rand_state = random.getstate()
|
|
|
|
|
|
|
|
|
|
numpy.random.seed(123)
|
|
|
|
|
random.seed(124)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def tearDownClass(cls):
|
|
|
|
|
"""Restore random seeds"""
|
|
|
|
|
numpy.random.set_state(cls._np_rand_state)
|
|
|
|
|
random.setstate(cls._py_rand_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSimpleMul(SeedFixedTestCase):
|
|
|
|
|
DATA_NAME = 'X'
|
|
|
|
|
DATA_WIDTH = 32
|
|
|
|
|
PARAM_NAME = 'W'
|
|
|
|
@ -263,7 +280,7 @@ class TestSimpleMul(unittest.TestCase):
|
|
|
|
|
self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.05))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSimpleMulWithMemory(unittest.TestCase):
|
|
|
|
|
class TestSimpleMulWithMemory(SeedFixedTestCase):
|
|
|
|
|
DATA_WIDTH = 32
|
|
|
|
|
HIDDEN_WIDTH = 20
|
|
|
|
|
DATA_NAME = 'X'
|
|
|
|
|