|
|
|
@ -40,10 +40,14 @@ class TestOperator(unittest.TestCase):
|
|
|
|
|
self.assertEqual(mul_op.input("Y"), ["mul.y"])
|
|
|
|
|
self.assertEqual(mul_op.output_names, ["Out"])
|
|
|
|
|
self.assertEqual(mul_op.output("Out"), ["mul.out"])
|
|
|
|
|
self.assertEqual(mul_op.attr_names, ["x_num_col_dims"])
|
|
|
|
|
self.assertEqual(
|
|
|
|
|
set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"]))
|
|
|
|
|
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
|
|
|
|
|
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
|
|
|
|
|
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
|
|
|
|
|
self.assertEqual(mul_op.has_attr("y_num_col_dims"), True)
|
|
|
|
|
self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT)
|
|
|
|
|
self.assertEqual(mul_op.attr("y_num_col_dims"), 1)
|
|
|
|
|
self.assertEqual(mul_out.op, mul_op)
|
|
|
|
|
|
|
|
|
|
def test_mult_input(self):
|
|
|
|
|