|
|
|
@ -420,5 +420,25 @@ class TestMSRAInitializer(unittest.TestCase):
|
|
|
|
|
self.assertEqual(init_op.type, 'assign_value')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestNumpyArrayInitializer(unittest.TestCase):
|
|
|
|
|
def test_numpy_array_initializer(self):
|
|
|
|
|
"""Test the numpy array initializer with supplied arguments
|
|
|
|
|
"""
|
|
|
|
|
import numpy
|
|
|
|
|
program = framework.Program()
|
|
|
|
|
block = program.global_block()
|
|
|
|
|
for _ in range(2):
|
|
|
|
|
np_array = numpy.array([1, 2, 3, 4]).astype('float32')
|
|
|
|
|
block.create_parameter(
|
|
|
|
|
dtype=np_array.dtype,
|
|
|
|
|
shape=np_array.shape,
|
|
|
|
|
lod_level=0,
|
|
|
|
|
name="param",
|
|
|
|
|
initializer=initializer.NumpyArrayInitializer(np_array))
|
|
|
|
|
self.assertEqual(len(block.ops), 1)
|
|
|
|
|
init_op = block.ops[0]
|
|
|
|
|
self.assertEqual(init_op.type, 'assign_value')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|