|
|
|
@ -14,7 +14,7 @@ class TestOperator(unittest.TestCase):
|
|
|
|
|
err.message,
|
|
|
|
|
"Operator with type \"no_such_op\" has not been registered.")
|
|
|
|
|
|
|
|
|
|
def test_input_output(self):
|
|
|
|
|
def test_op_desc_creation(self):
|
|
|
|
|
block = g_program.current_block()
|
|
|
|
|
mul_x = block.create_var(
|
|
|
|
|
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
|
|
|
|
@ -26,12 +26,18 @@ class TestOperator(unittest.TestCase):
|
|
|
|
|
type="mul",
|
|
|
|
|
inputs={"X": [mul_x],
|
|
|
|
|
"Y": mul_y},
|
|
|
|
|
outputs={"Out": [mul_out]})
|
|
|
|
|
outputs={"Out": [mul_out]},
|
|
|
|
|
attrs={"x_num_col_dims": 1})
|
|
|
|
|
self.assertEqual(mul_op.type, "mul")
|
|
|
|
|
self.assertEqual(mul_op.input_names, ["X", "Y"])
|
|
|
|
|
self.assertEqual(mul_op.input("X"), ["x"])
|
|
|
|
|
self.assertEqual(mul_op.input("X"), ["mul.x"])
|
|
|
|
|
self.assertEqual(mul_op.input("Y"), ["mul.y"])
|
|
|
|
|
self.assertEqual(mul_op.output_names, ["Out"])
|
|
|
|
|
self.assertEqual(mul_op.output("Out"), ["out"])
|
|
|
|
|
self.assertEqual(mul_op.output("Out"), ["mul.out"])
|
|
|
|
|
self.assertEqual(mul_op.attr_names, ["x_num_col_dims"])
|
|
|
|
|
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
|
|
|
|
|
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
|
|
|
|
|
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
|
|
|
|
|
|
|
|
|
|
def test_mult_input(self):
|
|
|
|
|
block = g_program.current_block()
|
|
|
|
@ -49,9 +55,9 @@ class TestOperator(unittest.TestCase):
|
|
|
|
|
outputs={"Out": sum_out})
|
|
|
|
|
self.assertEqual(sum_op.type, "sum")
|
|
|
|
|
self.assertEqual(sum_op.input_names, ["X"])
|
|
|
|
|
self.assertEqual(sum_op.input("X"), ["x1", "x2", "x3"])
|
|
|
|
|
self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"])
|
|
|
|
|
self.assertEqual(sum_op.output_names, ["Out"])
|
|
|
|
|
self.assertEqual(sum_op.output("Out"), ["out"])
|
|
|
|
|
self.assertEqual(sum_op.output("Out"), ["sum.out"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|