[Dy2stat] support usage: to_static(model) (#27040)

* support to_static(model)

* add warning and unittest
disable_ut_1
Aurelius84 5 years ago committed by GitHub
parent 1b84c0bf43
commit 5e0dde02b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -212,7 +212,16 @@ def declarative(function=None, input_spec=None):
# for usage: `declarative(foo, ...)`
if function is not None:
return decorated(function)
if isinstance(function, Layer):
if isinstance(function.forward, StaticLayer):
class_name = function.__class__.__name__
warnings.warn(
"`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
format(class_name))
function.forward = decorated(function.forward)
return function
else:
return decorated(function)
# for usage: `@declarative`
return decorated

@ -332,5 +332,31 @@ class TestDeclarativeAPI(unittest.TestCase):
func(np.ones(5).astype("int32"))
class TestDecorateModelDirectly(unittest.TestCase):
def setUp(self):
paddle.disable_static()
program_trans.enable(True)
self.x = to_variable(np.ones([4, 10]).astype('float32'))
def test_fake_input(self):
net = SimpleNet()
net = declarative(net)
y = net(self.x)
self.assertTrue(len(net.forward.program_cache) == 1)
def test_input_spec(self):
net = SimpleNet()
net = declarative(net, input_spec=[InputSpec([None, 8, 10])])
self.assertTrue(len(net.forward.inputs) == 1)
self.assertTrue(len(net.forward.program_cache) == 1)
input_shape = net.forward.inputs[0].shape
self.assertListEqual(list(input_shape), [-1, 8, 10])
# redecorate
net = declarative(net, input_spec=[InputSpec([None, 16, 10])])
input_shape = net.forward.inputs[0].shape
self.assertListEqual(list(input_shape), [-1, 16, 10])
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save