|
|
|
@ -21,6 +21,7 @@ import numpy as np
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle.static import InputSpec
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from paddle.fluid.layers.utils import flatten
|
|
|
|
|
from paddle.fluid.dygraph import Linear
|
|
|
|
|
from paddle.fluid.dygraph import declarative, ProgramTranslator
|
|
|
|
|
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
|
|
|
|
@ -153,6 +154,21 @@ class LinearNetReturnHidden(fluid.dygraph.Layer):
|
|
|
|
|
return y, loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LinearNetWithNestOut(fluid.dygraph.Layer):
|
|
|
|
|
def __init__(self, in_size, out_size):
|
|
|
|
|
super(LinearNetWithNestOut, self).__init__()
|
|
|
|
|
self._linear_1 = Linear(in_size, out_size)
|
|
|
|
|
self._linear_2 = Linear(in_size, out_size)
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
y = self._linear_1(x)
|
|
|
|
|
z = self._linear_2(y)
|
|
|
|
|
out = y + z
|
|
|
|
|
loss = fluid.layers.mean(out)
|
|
|
|
|
return y, [(z, loss), out]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmptyLayer(paddle.nn.Layer):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(EmptyLayer, self).__init__()
|
|
|
|
@ -299,6 +315,30 @@ class TestJitSaveLoad(unittest.TestCase):
|
|
|
|
|
loaded_layer = paddle.jit.load(path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSaveLoadWithNestOut(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
# enable dygraph mode
|
|
|
|
|
fluid.enable_dygraph()
|
|
|
|
|
|
|
|
|
|
def test_nest_output(self):
|
|
|
|
|
x = fluid.dygraph.to_variable(
|
|
|
|
|
np.random.random((4, 8)).astype('float32'))
|
|
|
|
|
|
|
|
|
|
net = LinearNetWithNestOut(8, 8)
|
|
|
|
|
dy_outs = flatten(net(x))
|
|
|
|
|
net = declarative(net, input_spec=[InputSpec([None, 8], name='x')])
|
|
|
|
|
|
|
|
|
|
model_path = "net_with_nest_out/model"
|
|
|
|
|
paddle.jit.save(net, model_path)
|
|
|
|
|
|
|
|
|
|
load_net = paddle.jit.load(model_path)
|
|
|
|
|
load_outs = flatten(load_net(x))
|
|
|
|
|
|
|
|
|
|
self.assertTrue(len(dy_outs) == 4)
|
|
|
|
|
for dy_out, load_out in zip(dy_outs, load_outs):
|
|
|
|
|
self.assertTrue(np.allclose(dy_out.numpy(), load_out.numpy()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSaveLoadWithInputSpec(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
# enable dygraph mode
|
|
|
|
|