|
|
@ -1,6 +1,6 @@
|
|
|
|
import unittest
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
|
|
import paddle.v2.framework.core as core
|
|
|
|
import paddle.v2.framework.core as core
|
|
|
|
from paddle.v2.framework.op import Operator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestInferShape(unittest.TestCase):
|
|
|
|
class TestInferShape(unittest.TestCase):
|
|
|
@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
|
|
|
|
sum_op_desc.set_input("X", ["x1", "x2"])
|
|
|
|
sum_op_desc.set_input("X", ["x1", "x2"])
|
|
|
|
sum_op_desc.set_output("Out", ["out"])
|
|
|
|
sum_op_desc.set_output("Out", ["out"])
|
|
|
|
|
|
|
|
|
|
|
|
core.Operator.infer_shape(sum_op_desc, block)
|
|
|
|
sum_op_desc.infer_shape(block)
|
|
|
|
self.assertEqual(out.shape(), shape)
|
|
|
|
self.assertEqual(out.shape(), shape)
|
|
|
|
|
|
|
|
|
|
|
|
def test_mul_op(self):
|
|
|
|
def test_mul_op(self):
|
|
|
@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
|
|
|
|
mul_op_desc.set_attr("x_num_col_dims", 1)
|
|
|
|
mul_op_desc.set_attr("x_num_col_dims", 1)
|
|
|
|
mul_op_desc.set_attr("y_num_col_dims", 1)
|
|
|
|
mul_op_desc.set_attr("y_num_col_dims", 1)
|
|
|
|
|
|
|
|
|
|
|
|
core.Operator.infer_shape(mul_op_desc, block)
|
|
|
|
mul_op_desc.infer_shape(block)
|
|
|
|
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
|
|
|
|
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|