Merge pull request #4684 from reyoung/feature/parameter
Feature/parameterrevert-4814-Add_sequence_project_op
commit
ee22a436a8
@ -0,0 +1,27 @@
|
||||
import unittest
|
||||
from paddle.v2.framework.graph import g_program
|
||||
import paddle.v2.framework.core as core
|
||||
|
||||
|
||||
class TestParameter(unittest.TestCase):
|
||||
def test_param(self):
|
||||
b = g_program.create_block()
|
||||
param = b.create_parameter(
|
||||
name='fc.w',
|
||||
shape=[784, 100],
|
||||
dtype='float32',
|
||||
initialize_attr={
|
||||
'type': 'uniform_random',
|
||||
'seed': 13,
|
||||
'min': -5.0,
|
||||
'max': 5.0
|
||||
})
|
||||
self.assertIsNotNone(param)
|
||||
self.assertEqual('fc.w', param.name)
|
||||
self.assertEqual((784, 100), param.shape)
|
||||
self.assertEqual(core.DataType.FP32, param.data_type)
|
||||
self.assertEqual(0, param.block.idx)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,40 @@
|
||||
import unittest
|
||||
from paddle.v2.framework.graph import Variable, g_program
|
||||
import paddle.v2.framework.core as core
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestVariable(unittest.TestCase):
|
||||
def test_np_dtype_convert(self):
|
||||
DT = core.DataType
|
||||
convert = Variable._convert_np_dtype_to_dtype_
|
||||
self.assertEqual(DT.FP32, convert(np.float32))
|
||||
self.assertEqual(DT.FP16, convert("float16"))
|
||||
self.assertEqual(DT.FP64, convert("float64"))
|
||||
self.assertEqual(DT.INT32, convert("int32"))
|
||||
self.assertEqual(DT.INT16, convert("int16"))
|
||||
self.assertEqual(DT.INT64, convert("int64"))
|
||||
self.assertEqual(DT.BOOL, convert("bool"))
|
||||
self.assertRaises(ValueError, lambda: convert("int8"))
|
||||
|
||||
def test_var(self):
|
||||
b = g_program.current_block()
|
||||
w = b.create_var(
|
||||
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
|
||||
self.assertEqual(core.DataType.FP64, w.data_type)
|
||||
self.assertEqual((784, 100), w.shape)
|
||||
self.assertEqual("fc.w", w.name)
|
||||
self.assertEqual(0, w.lod_level)
|
||||
|
||||
w = b.create_var(name='fc.w')
|
||||
self.assertEqual(core.DataType.FP64, w.data_type)
|
||||
self.assertEqual((784, 100), w.shape)
|
||||
self.assertEqual("fc.w", w.name)
|
||||
self.assertEqual(0, w.lod_level)
|
||||
|
||||
self.assertRaises(ValueError,
|
||||
lambda: b.create_var(name="fc.w", shape=(24, 100)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue