|
|
|
@ -332,5 +332,31 @@ class TestDeclarativeAPI(unittest.TestCase):
|
|
|
|
|
func(np.ones(5).astype("int32"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDecorateModelDirectly(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
program_trans.enable(True)
|
|
|
|
|
self.x = to_variable(np.ones([4, 10]).astype('float32'))
|
|
|
|
|
|
|
|
|
|
def test_fake_input(self):
|
|
|
|
|
net = SimpleNet()
|
|
|
|
|
net = declarative(net)
|
|
|
|
|
y = net(self.x)
|
|
|
|
|
self.assertTrue(len(net.forward.program_cache) == 1)
|
|
|
|
|
|
|
|
|
|
def test_input_spec(self):
|
|
|
|
|
net = SimpleNet()
|
|
|
|
|
net = declarative(net, input_spec=[InputSpec([None, 8, 10])])
|
|
|
|
|
self.assertTrue(len(net.forward.inputs) == 1)
|
|
|
|
|
self.assertTrue(len(net.forward.program_cache) == 1)
|
|
|
|
|
input_shape = net.forward.inputs[0].shape
|
|
|
|
|
self.assertListEqual(list(input_shape), [-1, 8, 10])
|
|
|
|
|
|
|
|
|
|
# redecorate
|
|
|
|
|
net = declarative(net, input_spec=[InputSpec([None, 16, 10])])
|
|
|
|
|
input_shape = net.forward.inputs[0].shape
|
|
|
|
|
self.assertListEqual(list(input_shape), [-1, 16, 10])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|