|
|
|
@ -3,7 +3,7 @@ import paddle.v2.framework.create_op_creation_methods as creation
|
|
|
|
|
import paddle.v2.framework.core as core
|
|
|
|
|
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
|
|
|
|
|
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
|
|
|
|
|
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
|
|
|
|
|
import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGetAllProtos(unittest.TestCase):
|
|
|
|
@ -76,7 +76,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
|
expected1.type = 'fc'
|
|
|
|
|
attr = expected1.attrs.add()
|
|
|
|
|
attr.name = 'input_format'
|
|
|
|
|
attr.type = attr_type_pb2.INTS
|
|
|
|
|
attr.type = attribute_pb2.INTS
|
|
|
|
|
attr.ints.extend([0, 1, 2, 3])
|
|
|
|
|
self.assertEqual(expected1, generated1)
|
|
|
|
|
|
|
|
|
@ -88,7 +88,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
|
expected2.type = 'fc'
|
|
|
|
|
attr = expected2.attrs.add()
|
|
|
|
|
attr.name = 'input_format'
|
|
|
|
|
attr.type = attr_type_pb2.INTS
|
|
|
|
|
attr.type = attribute_pb2.INTS
|
|
|
|
|
attr.ints.extend([0, 3, 6, 7])
|
|
|
|
|
self.assertEqual(expected2, generated2)
|
|
|
|
|
|
|
|
|
@ -105,12 +105,12 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
|
attr.comment = ""
|
|
|
|
|
attr.type = type
|
|
|
|
|
|
|
|
|
|
__add_attr__("int_attr", attr_type_pb2.INT)
|
|
|
|
|
__add_attr__("float_attr", attr_type_pb2.FLOAT)
|
|
|
|
|
__add_attr__("string_attr", attr_type_pb2.STRING)
|
|
|
|
|
__add_attr__("ints_attr", attr_type_pb2.INTS)
|
|
|
|
|
__add_attr__("floats_attr", attr_type_pb2.FLOATS)
|
|
|
|
|
__add_attr__("strings_attr", attr_type_pb2.STRINGS)
|
|
|
|
|
__add_attr__("int_attr", attribute_pb2.INT)
|
|
|
|
|
__add_attr__("float_attr", attribute_pb2.FLOAT)
|
|
|
|
|
__add_attr__("string_attr", attribute_pb2.STRING)
|
|
|
|
|
__add_attr__("ints_attr", attribute_pb2.INTS)
|
|
|
|
|
__add_attr__("floats_attr", attribute_pb2.FLOATS)
|
|
|
|
|
__add_attr__("strings_attr", attribute_pb2.STRINGS)
|
|
|
|
|
|
|
|
|
|
op.comment = ""
|
|
|
|
|
self.assertTrue(op.IsInitialized())
|
|
|
|
@ -131,32 +131,32 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
|
expected.inputs.extend(['a'])
|
|
|
|
|
attr = expected.attrs.add()
|
|
|
|
|
attr.name = "int_attr"
|
|
|
|
|
attr.type = attr_type_pb2.INT
|
|
|
|
|
attr.type = attribute_pb2.INT
|
|
|
|
|
attr.i = 10
|
|
|
|
|
|
|
|
|
|
attr = expected.attrs.add()
|
|
|
|
|
attr.name = "float_attr"
|
|
|
|
|
attr.type = attr_type_pb2.FLOAT
|
|
|
|
|
attr.type = attribute_pb2.FLOAT
|
|
|
|
|
attr.f = 3.2
|
|
|
|
|
|
|
|
|
|
attr = expected.attrs.add()
|
|
|
|
|
attr.name = "string_attr"
|
|
|
|
|
attr.type = attr_type_pb2.STRING
|
|
|
|
|
attr.type = attribute_pb2.STRING
|
|
|
|
|
attr.s = "test_str"
|
|
|
|
|
|
|
|
|
|
attr = expected.attrs.add()
|
|
|
|
|
attr.name = "ints_attr"
|
|
|
|
|
attr.type = attr_type_pb2.INTS
|
|
|
|
|
attr.type = attribute_pb2.INTS
|
|
|
|
|
attr.ints.extend([0, 1, 2, 3, 4])
|
|
|
|
|
|
|
|
|
|
attr = expected.attrs.add()
|
|
|
|
|
attr.name = 'floats_attr'
|
|
|
|
|
attr.type = attr_type_pb2.FLOATS
|
|
|
|
|
attr.type = attribute_pb2.FLOATS
|
|
|
|
|
attr.floats.extend([0.2, 3.2, 4.5])
|
|
|
|
|
|
|
|
|
|
attr = expected.attrs.add()
|
|
|
|
|
attr.name = 'strings_attr'
|
|
|
|
|
attr.type = attr_type_pb2.STRINGS
|
|
|
|
|
attr.type = attribute_pb2.STRINGS
|
|
|
|
|
attr.strings.extend(['a', 'b', 'c'])
|
|
|
|
|
|
|
|
|
|
self.assertEqual(expected, generated)
|
|
|
|
@ -185,7 +185,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
|
desc.type = "test"
|
|
|
|
|
attr = desc.attrs.add()
|
|
|
|
|
attr.name = "temporary_index"
|
|
|
|
|
attr.type = attr_type_pb2.INTS
|
|
|
|
|
attr.type = attribute_pb2.INTS
|
|
|
|
|
attr.ints.append(2)
|
|
|
|
|
self.assertEqual(generated, desc)
|
|
|
|
|
|
|
|
|
@ -219,7 +219,7 @@ This op is used for unit test, not a real op.
|
|
|
|
|
|
|
|
|
|
test_str = op.attrs.add()
|
|
|
|
|
test_str.name = "str_attr"
|
|
|
|
|
test_str.type = attr_type_pb2.STRING
|
|
|
|
|
test_str.type = attribute_pb2.STRING
|
|
|
|
|
test_str.comment = "A string attribute for test op"
|
|
|
|
|
|
|
|
|
|
actual = creation.get_docstring_from_op_proto(op)
|
|
|
|
|