Merge pull request #4684 from reyoung/feature/parameter
Feature/parameterrevert-4814-Add_sequence_project_op
commit
ee22a436a8
@ -0,0 +1,27 @@
|
|||||||
|
import unittest
|
||||||
|
from paddle.v2.framework.graph import g_program
|
||||||
|
import paddle.v2.framework.core as core
|
||||||
|
|
||||||
|
|
||||||
|
class TestParameter(unittest.TestCase):
|
||||||
|
def test_param(self):
|
||||||
|
b = g_program.create_block()
|
||||||
|
param = b.create_parameter(
|
||||||
|
name='fc.w',
|
||||||
|
shape=[784, 100],
|
||||||
|
dtype='float32',
|
||||||
|
initialize_attr={
|
||||||
|
'type': 'uniform_random',
|
||||||
|
'seed': 13,
|
||||||
|
'min': -5.0,
|
||||||
|
'max': 5.0
|
||||||
|
})
|
||||||
|
self.assertIsNotNone(param)
|
||||||
|
self.assertEqual('fc.w', param.name)
|
||||||
|
self.assertEqual((784, 100), param.shape)
|
||||||
|
self.assertEqual(core.DataType.FP32, param.data_type)
|
||||||
|
self.assertEqual(0, param.block.idx)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -0,0 +1,40 @@
|
|||||||
|
import unittest
|
||||||
|
from paddle.v2.framework.graph import Variable, g_program
|
||||||
|
import paddle.v2.framework.core as core
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class TestVariable(unittest.TestCase):
|
||||||
|
def test_np_dtype_convert(self):
|
||||||
|
DT = core.DataType
|
||||||
|
convert = Variable._convert_np_dtype_to_dtype_
|
||||||
|
self.assertEqual(DT.FP32, convert(np.float32))
|
||||||
|
self.assertEqual(DT.FP16, convert("float16"))
|
||||||
|
self.assertEqual(DT.FP64, convert("float64"))
|
||||||
|
self.assertEqual(DT.INT32, convert("int32"))
|
||||||
|
self.assertEqual(DT.INT16, convert("int16"))
|
||||||
|
self.assertEqual(DT.INT64, convert("int64"))
|
||||||
|
self.assertEqual(DT.BOOL, convert("bool"))
|
||||||
|
self.assertRaises(ValueError, lambda: convert("int8"))
|
||||||
|
|
||||||
|
def test_var(self):
|
||||||
|
b = g_program.current_block()
|
||||||
|
w = b.create_var(
|
||||||
|
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
|
||||||
|
self.assertEqual(core.DataType.FP64, w.data_type)
|
||||||
|
self.assertEqual((784, 100), w.shape)
|
||||||
|
self.assertEqual("fc.w", w.name)
|
||||||
|
self.assertEqual(0, w.lod_level)
|
||||||
|
|
||||||
|
w = b.create_var(name='fc.w')
|
||||||
|
self.assertEqual(core.DataType.FP64, w.data_type)
|
||||||
|
self.assertEqual((784, 100), w.shape)
|
||||||
|
self.assertEqual("fc.w", w.name)
|
||||||
|
self.assertEqual(0, w.lod_level)
|
||||||
|
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
lambda: b.create_var(name="fc.w", shape=(24, 100)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue