|
|
|
@ -222,6 +222,16 @@ class LinearNetWithDictInput(paddle.nn.Layer):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LinearNetWithDictInputNoPrune(paddle.nn.Layer):
|
|
|
|
|
def __init__(self, in_size, out_size):
|
|
|
|
|
super(LinearNetWithDictInputNoPrune, self).__init__()
|
|
|
|
|
self._linear = Linear(in_size, out_size)
|
|
|
|
|
|
|
|
|
|
def forward(self, img):
|
|
|
|
|
out = self._linear(img['img'] + img['img2'])
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmptyLayer(paddle.nn.Layer):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(EmptyLayer, self).__init__()
|
|
|
|
@ -443,6 +453,30 @@ class TestSaveLoadWithDictInput(unittest.TestCase):
|
|
|
|
|
self.assertEqual(len(loaded_net._input_spec()), 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSaveLoadWithDictInputNoPrune(unittest.TestCase):
|
|
|
|
|
def test_dict_input(self):
|
|
|
|
|
net = LinearNetWithDictInputNoPrune(8, 8)
|
|
|
|
|
|
|
|
|
|
path = "test_jit_save_load_with_dict_input_no_prune/model"
|
|
|
|
|
# prune inputs
|
|
|
|
|
paddle.jit.save(
|
|
|
|
|
layer=net,
|
|
|
|
|
path=path,
|
|
|
|
|
input_spec=[{
|
|
|
|
|
'img': InputSpec(
|
|
|
|
|
shape=[None, 8], dtype='float32', name='img'),
|
|
|
|
|
'img2': InputSpec(
|
|
|
|
|
shape=[None, 8], dtype='float32', name='img2')
|
|
|
|
|
}])
|
|
|
|
|
|
|
|
|
|
img = paddle.randn(shape=[4, 8], dtype='float32')
|
|
|
|
|
img2 = paddle.randn(shape=[4, 8], dtype='float32')
|
|
|
|
|
loaded_net = paddle.jit.load(path)
|
|
|
|
|
loaded_out = loaded_net(img, img2)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(loaded_net._input_spec()), 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSaveLoadWithInputSpec(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
# enable dygraph mode
|
|
|
|
|