@ -1,6 +1,6 @@
import unittest
import unittest
import paddle . v2 . framework . core as core
import paddle . v2 . framework . core as core
from paddle . v2 . framework . op import Operator
class TestInferShape ( unittest . TestCase ) :
class TestInferShape ( unittest . TestCase ) :
@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
sum_op_desc . set_input ( " X " , [ " x1 " , " x2 " ] )
sum_op_desc . set_input ( " X " , [ " x1 " , " x2 " ] )
sum_op_desc . set_output ( " Out " , [ " out " ] )
sum_op_desc . set_output ( " Out " , [ " out " ] )
core. Operator . infer_shape ( sum_op_desc , block )
sum_op_desc. infer_shape ( block )
self . assertEqual ( out . shape ( ) , shape )
self . assertEqual ( out . shape ( ) , shape )
def test_mul_op ( self ) :
def test_mul_op ( self ) :
@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
mul_op_desc . set_attr ( " x_num_col_dims " , 1 )
mul_op_desc . set_attr ( " x_num_col_dims " , 1 )
mul_op_desc . set_attr ( " y_num_col_dims " , 1 )
mul_op_desc . set_attr ( " y_num_col_dims " , 1 )
core. Operator . infer_shape ( mul_op_desc , block )
mul_op_desc. infer_shape ( block )
self . assertEqual ( out . shape ( ) , [ x_shape [ 0 ] , y_shape [ 1 ] ] )
self . assertEqual ( out . shape ( ) , [ x_shape [ 0 ] , y_shape [ 1 ] ] )