You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/python/paddle/v2/framework/tests/test_infer_shape.py

30 lines
864 B

import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
class TestInferShape(unittest.TestCase):
    """Tests for compile-time shape inference driven from Python."""

    def test_sum_op(self):
        """Infer the output shape of a two-input ``sum`` op and check it.

        Builds a program description, declares two input variables with the
        same shape in block 0, appends a ``sum`` op description wiring them to
        an ``out`` variable, runs ``Operator.infer_shape`` on it, and asserts
        that ``out`` received the inputs' shape.
        """
        prog = core.ProgramDesc.__create_program_desc__()
        self.assertIsNotNone(prog)
        block = prog.block(0)
        self.assertIsNotNone(block)

        # Both inputs share one shape, so the expected output shape is
        # stated in exactly one place.
        shape = [10, 20]

        # prepare input/output
        x1 = block.new_var("x1")
        x1.set_shape(shape)
        x2 = block.new_var("x2")
        x2.set_shape(shape)
        out = block.new_var("out")

        # prepare the operator description that shape inference will read
        sum_op_desc = block.append_op()
        sum_op_desc.set_type("sum")
        sum_op_desc.set_input("X", ["x1", "x2"])
        sum_op_desc.set_output("Out", ["out"])

        sum_op = Operator("sum", X=["x1", "x2"], Out="out")
        sum_op.infer_shape(sum_op_desc, block)

        # Previously this only printed the shape, so the test could never
        # fail. Summing identically shaped inputs should yield that same
        # shape. list() normalizes whatever sequence type the binding
        # returns (presumably a list; verify against the pybind layer).
        self.assertEqual(list(out.shape()), shape)