|
|
@ -1,5 +1,5 @@
|
|
|
|
import unittest
|
|
|
|
import unittest
|
|
|
|
import paddle.v2.framework.create_op_creation_methods as creation
|
|
|
|
import paddle.v2.framework.op as op
|
|
|
|
import paddle.v2.framework.core as core
|
|
|
|
import paddle.v2.framework.core as core
|
|
|
|
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
|
|
|
|
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
|
|
|
|
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
|
|
|
|
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
|
|
|
@ -8,7 +8,7 @@ import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
|
|
|
|
|
|
|
|
|
|
|
|
class TestGetAllProtos(unittest.TestCase):
|
|
|
|
class TestGetAllProtos(unittest.TestCase):
|
|
|
|
def test_all(self):
|
|
|
|
def test_all(self):
|
|
|
|
all_protos = creation.get_all_op_protos()
|
|
|
|
all_protos = op.get_all_op_protos()
|
|
|
|
self.assertNotEqual(0, len(all_protos))
|
|
|
|
self.assertNotEqual(0, len(all_protos))
|
|
|
|
|
|
|
|
|
|
|
|
for each in all_protos:
|
|
|
|
for each in all_protos:
|
|
|
@ -17,25 +17,25 @@ class TestGetAllProtos(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
|
|
class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
def test_plain_input_output(self):
|
|
|
|
def test_plain_input_output(self):
|
|
|
|
op = op_proto_pb2.OpProto()
|
|
|
|
op_proto = op_proto_pb2.OpProto()
|
|
|
|
op.type = "test"
|
|
|
|
op_proto.type = "test"
|
|
|
|
ipt = op.inputs.add()
|
|
|
|
ipt = op_proto.inputs.add()
|
|
|
|
ipt.name = "X"
|
|
|
|
ipt.name = "X"
|
|
|
|
ipt.comment = "not matter"
|
|
|
|
ipt.comment = "not matter"
|
|
|
|
|
|
|
|
|
|
|
|
ipt = op.inputs.add()
|
|
|
|
ipt = op_proto.inputs.add()
|
|
|
|
ipt.name = "Y"
|
|
|
|
ipt.name = "Y"
|
|
|
|
ipt.comment = "not matter"
|
|
|
|
ipt.comment = "not matter"
|
|
|
|
|
|
|
|
|
|
|
|
opt = op.outputs.add()
|
|
|
|
opt = op_proto.outputs.add()
|
|
|
|
opt.name = "Z"
|
|
|
|
opt.name = "Z"
|
|
|
|
opt.comment = "not matter"
|
|
|
|
opt.comment = "not matter"
|
|
|
|
|
|
|
|
|
|
|
|
op.comment = "not matter"
|
|
|
|
op_proto.comment = "not matter"
|
|
|
|
|
|
|
|
|
|
|
|
self.assertTrue(op.IsInitialized())
|
|
|
|
self.assertTrue(op_proto.IsInitialized())
|
|
|
|
|
|
|
|
|
|
|
|
method = creation.OpDescCreationMethod(op)
|
|
|
|
method = op.OpDescCreationMethod(op_proto)
|
|
|
|
output = method(X="a", Y="b", Z="c")
|
|
|
|
output = method(X="a", Y="b", Z="c")
|
|
|
|
|
|
|
|
|
|
|
|
expected = op_desc_pb2.OpDesc()
|
|
|
|
expected = op_desc_pb2.OpDesc()
|
|
|
@ -45,29 +45,29 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
self.assertEqual(expected, output)
|
|
|
|
self.assertEqual(expected, output)
|
|
|
|
|
|
|
|
|
|
|
|
def test_multiple_input_plain_output(self):
|
|
|
|
def test_multiple_input_plain_output(self):
|
|
|
|
op = op_proto_pb2.OpProto()
|
|
|
|
op_proto = op_proto_pb2.OpProto()
|
|
|
|
op.type = "fc"
|
|
|
|
op_proto.type = "fc"
|
|
|
|
ipt = op.inputs.add()
|
|
|
|
ipt = op_proto.inputs.add()
|
|
|
|
ipt.name = "X"
|
|
|
|
ipt.name = "X"
|
|
|
|
ipt.comment = ""
|
|
|
|
ipt.comment = ""
|
|
|
|
ipt.multiple = True
|
|
|
|
ipt.multiple = True
|
|
|
|
|
|
|
|
|
|
|
|
ipt = op.inputs.add()
|
|
|
|
ipt = op_proto.inputs.add()
|
|
|
|
ipt.name = "W"
|
|
|
|
ipt.name = "W"
|
|
|
|
ipt.comment = ""
|
|
|
|
ipt.comment = ""
|
|
|
|
ipt.multiple = True
|
|
|
|
ipt.multiple = True
|
|
|
|
|
|
|
|
|
|
|
|
ipt = op.inputs.add()
|
|
|
|
ipt = op_proto.inputs.add()
|
|
|
|
ipt.name = "b"
|
|
|
|
ipt.name = "b"
|
|
|
|
ipt.comment = ""
|
|
|
|
ipt.comment = ""
|
|
|
|
|
|
|
|
|
|
|
|
out = op.outputs.add()
|
|
|
|
out = op_proto.outputs.add()
|
|
|
|
out.name = "Y"
|
|
|
|
out.name = "Y"
|
|
|
|
out.comment = ""
|
|
|
|
out.comment = ""
|
|
|
|
|
|
|
|
|
|
|
|
op.comment = ""
|
|
|
|
op_proto.comment = ""
|
|
|
|
self.assertTrue(op.IsInitialized())
|
|
|
|
self.assertTrue(op_proto.IsInitialized())
|
|
|
|
method = creation.OpDescCreationMethod(op)
|
|
|
|
method = op.OpDescCreationMethod(op_proto)
|
|
|
|
|
|
|
|
|
|
|
|
generated1 = method(X="x", W="w", b="b", Y="y")
|
|
|
|
generated1 = method(X="x", W="w", b="b", Y="y")
|
|
|
|
expected1 = op_desc_pb2.OpDesc()
|
|
|
|
expected1 = op_desc_pb2.OpDesc()
|
|
|
@ -93,14 +93,14 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
self.assertEqual(expected2, generated2)
|
|
|
|
self.assertEqual(expected2, generated2)
|
|
|
|
|
|
|
|
|
|
|
|
def test_attrs(self):
|
|
|
|
def test_attrs(self):
|
|
|
|
op = op_proto_pb2.OpProto()
|
|
|
|
op_proto = op_proto_pb2.OpProto()
|
|
|
|
op.type = "test"
|
|
|
|
op_proto.type = "test"
|
|
|
|
ipt = op.inputs.add()
|
|
|
|
ipt = op_proto.inputs.add()
|
|
|
|
ipt.name = 'X'
|
|
|
|
ipt.name = 'X'
|
|
|
|
ipt.comment = ""
|
|
|
|
ipt.comment = ""
|
|
|
|
|
|
|
|
|
|
|
|
def __add_attr__(name, type):
|
|
|
|
def __add_attr__(name, type):
|
|
|
|
attr = op.attrs.add()
|
|
|
|
attr = op_proto.attrs.add()
|
|
|
|
attr.name = name
|
|
|
|
attr.name = name
|
|
|
|
attr.comment = ""
|
|
|
|
attr.comment = ""
|
|
|
|
attr.type = type
|
|
|
|
attr.type = type
|
|
|
@ -112,10 +112,10 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
__add_attr__("floats_attr", attr_type_pb2.FLOATS)
|
|
|
|
__add_attr__("floats_attr", attr_type_pb2.FLOATS)
|
|
|
|
__add_attr__("strings_attr", attr_type_pb2.STRINGS)
|
|
|
|
__add_attr__("strings_attr", attr_type_pb2.STRINGS)
|
|
|
|
|
|
|
|
|
|
|
|
op.comment = ""
|
|
|
|
op_proto.comment = ""
|
|
|
|
self.assertTrue(op.IsInitialized())
|
|
|
|
self.assertTrue(op_proto.IsInitialized())
|
|
|
|
|
|
|
|
|
|
|
|
method = creation.OpDescCreationMethod(op)
|
|
|
|
method = op.OpDescCreationMethod(op_proto)
|
|
|
|
|
|
|
|
|
|
|
|
generated = method(
|
|
|
|
generated = method(
|
|
|
|
X="a",
|
|
|
|
X="a",
|
|
|
@ -162,23 +162,23 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
self.assertEqual(expected, generated)
|
|
|
|
self.assertEqual(expected, generated)
|
|
|
|
|
|
|
|
|
|
|
|
def test_input_temporary_output(self):
|
|
|
|
def test_input_temporary_output(self):
|
|
|
|
op = op_proto_pb2.OpProto()
|
|
|
|
op_proto = op_proto_pb2.OpProto()
|
|
|
|
op.type = "test"
|
|
|
|
op_proto.type = "test"
|
|
|
|
out = op.outputs.add()
|
|
|
|
out = op_proto.outputs.add()
|
|
|
|
out.name = "OUT"
|
|
|
|
out.name = "OUT"
|
|
|
|
out.comment = ""
|
|
|
|
out.comment = ""
|
|
|
|
|
|
|
|
|
|
|
|
out = op.outputs.add()
|
|
|
|
out = op_proto.outputs.add()
|
|
|
|
out.name = "TMP"
|
|
|
|
out.name = "TMP"
|
|
|
|
out.comment = ""
|
|
|
|
out.comment = ""
|
|
|
|
out.temporary = True
|
|
|
|
out.temporary = True
|
|
|
|
|
|
|
|
|
|
|
|
out = op.outputs.add()
|
|
|
|
out = op_proto.outputs.add()
|
|
|
|
out.name = "OUT2"
|
|
|
|
out.name = "OUT2"
|
|
|
|
out.comment = ""
|
|
|
|
out.comment = ""
|
|
|
|
op.comment = ""
|
|
|
|
op_proto.comment = ""
|
|
|
|
|
|
|
|
|
|
|
|
method = creation.OpDescCreationMethod(op)
|
|
|
|
method = op.OpDescCreationMethod(op_proto)
|
|
|
|
generated = method(OUT="a", OUT2="b")
|
|
|
|
generated = method(OUT="a", OUT2="b")
|
|
|
|
desc = op_desc_pb2.OpDesc()
|
|
|
|
desc = op_desc_pb2.OpDesc()
|
|
|
|
desc.outputs.extend(["a", core.var_names.temp(), "b"])
|
|
|
|
desc.outputs.extend(["a", core.var_names.temp(), "b"])
|
|
|
@ -190,60 +190,9 @@ class TestOpDescCreationMethod(unittest.TestCase):
|
|
|
|
self.assertEqual(generated, desc)
|
|
|
|
self.assertEqual(generated, desc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestOpCreationDocStr(unittest.TestCase):
|
|
|
|
|
|
|
|
def test_all(self):
|
|
|
|
|
|
|
|
op = op_proto_pb2.OpProto()
|
|
|
|
|
|
|
|
op.type = "test"
|
|
|
|
|
|
|
|
op.comment = """Test Op.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
This op is used for unit test, not a real op.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
a = op.inputs.add()
|
|
|
|
|
|
|
|
a.name = "a"
|
|
|
|
|
|
|
|
a.comment = "Input a for test op"
|
|
|
|
|
|
|
|
a.multiple = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
b = op.inputs.add()
|
|
|
|
|
|
|
|
b.name = "b"
|
|
|
|
|
|
|
|
b.comment = "Input b for test op"
|
|
|
|
|
|
|
|
self.assertTrue(op.IsInitialized())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
o1 = op.outputs.add()
|
|
|
|
|
|
|
|
o1.name = "output"
|
|
|
|
|
|
|
|
o1.comment = "The output of test op"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
o2 = op.outputs.add()
|
|
|
|
|
|
|
|
o2.name = "temp output"
|
|
|
|
|
|
|
|
o2.comment = "The temporary output of test op"
|
|
|
|
|
|
|
|
o2.temporary = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_str = op.attrs.add()
|
|
|
|
|
|
|
|
test_str.name = "str_attr"
|
|
|
|
|
|
|
|
test_str.type = attr_type_pb2.STRING
|
|
|
|
|
|
|
|
test_str.comment = "A string attribute for test op"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
actual = creation.get_docstring_from_op_proto(op)
|
|
|
|
|
|
|
|
expected_docstring = '''Test Op.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
This op is used for unit test, not a real op.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param a: Input a for test op
|
|
|
|
|
|
|
|
:type a: list | basestr
|
|
|
|
|
|
|
|
:param b: Input b for test op
|
|
|
|
|
|
|
|
:type b: basestr
|
|
|
|
|
|
|
|
:param output: The output of test op
|
|
|
|
|
|
|
|
:type output: basestr
|
|
|
|
|
|
|
|
:param temp output: This is a temporary variable. It does not have to set by user. The temporary output of test op
|
|
|
|
|
|
|
|
:type temp output: basestr
|
|
|
|
|
|
|
|
:param str_attr: A string attribute for test op
|
|
|
|
|
|
|
|
:type str_attr: basestr
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
self.assertEqual(expected_docstring, actual)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestOpCreations(unittest.TestCase):
|
|
|
|
class TestOpCreations(unittest.TestCase):
|
|
|
|
def test_all(self):
|
|
|
|
def test_all(self):
|
|
|
|
add_op = creation.op_creations.add_two(X="a", Y="b", Out="z")
|
|
|
|
add_op = op.Operator("add_two", X="a", Y="b", Out="z")
|
|
|
|
self.assertIsNotNone(add_op)
|
|
|
|
self.assertIsNotNone(add_op)
|
|
|
|
# Invoke C++ DebugString()
|
|
|
|
# Invoke C++ DebugString()
|
|
|
|
self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).',
|
|
|
|
self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).',
|