Fix unittest bugs

revert-4814-Add_sequence_project_op
fengjiayi 8 years ago
parent e71b836f53
commit 906f5e8a26

@ -160,7 +160,7 @@ class Operator(object):
(in_proto.name, len(in_argus)))
in_argu_names = []
for argu in in_argus:
in_argu_names.append(argu.name())
in_argu_names.append(argu.name)
self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None:
@ -174,7 +174,7 @@ class Operator(object):
(out_proto.name, len(out_argus)))
out_argu_names = []
for argu in out_argus:
out_argu_names.append(argu.name())
out_argu_names.append(argu.name)
self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None:

@ -14,7 +14,7 @@ class TestOperator(unittest.TestCase):
err.message,
"Operator with type \"no_such_op\" has not been registered.")
def test_input_output(self):
def test_op_desc_creation(self):
block = g_program.current_block()
mul_x = block.create_var(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
@ -26,12 +26,18 @@ class TestOperator(unittest.TestCase):
type="mul",
inputs={"X": [mul_x],
"Y": mul_y},
outputs={"Out": [mul_out]})
outputs={"Out": [mul_out]},
attrs={"x_num_col_dims": 1})
self.assertEqual(mul_op.type, "mul")
self.assertEqual(mul_op.input_names, ["X", "Y"])
self.assertEqual(mul_op.input("X"), ["x"])
self.assertEqual(mul_op.input("X"), ["mul.x"])
self.assertEqual(mul_op.input("Y"), ["mul.y"])
self.assertEqual(mul_op.output_names, ["Out"])
self.assertEqual(mul_op.output("Out"), ["out"])
self.assertEqual(mul_op.output("Out"), ["mul.out"])
self.assertEqual(mul_op.attr_names, ["x_num_col_dims"])
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
def test_mult_input(self):
block = g_program.current_block()
@ -49,9 +55,9 @@ class TestOperator(unittest.TestCase):
outputs={"Out": sum_out})
self.assertEqual(sum_op.type, "sum")
self.assertEqual(sum_op.input_names, ["X"])
self.assertEqual(sum_op.input("X"), ["x1", "x2", "x3"])
self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"])
self.assertEqual(sum_op.output_names, ["Out"])
self.assertEqual(sum_op.output("Out"), ["out"])
self.assertEqual(sum_op.output("Out"), ["sum.out"])
if __name__ == '__main__':

Loading…
Cancel
Save