|
|
|
@ -17,13 +17,25 @@ from __future__ import print_function
|
|
|
|
|
import unittest
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import numpy as np
|
|
|
|
|
import paddle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MyLayer(fluid.Layer):
|
|
|
|
|
def __init__(self, num_stacked_param):
|
|
|
|
|
def __init__(self, num_stacked_param, use_fluid_api):
|
|
|
|
|
super(MyLayer, self).__init__()
|
|
|
|
|
# create ParameterList with iterable Parameters
|
|
|
|
|
self.params = fluid.dygraph.ParameterList(
|
|
|
|
|
self.params = self.fluid_dygraph_ParameterList(
|
|
|
|
|
num_stacked_param
|
|
|
|
|
) if use_fluid_api else self.paddle_imperative_ParameterList(
|
|
|
|
|
num_stacked_param)
|
|
|
|
|
|
|
|
|
|
def fluid_dygraph_ParameterList(self, num_stacked_param):
|
|
|
|
|
return fluid.dygraph.ParameterList(
|
|
|
|
|
[fluid.layers.create_parameter(
|
|
|
|
|
shape=[2, 2], dtype='float32')] * num_stacked_param)
|
|
|
|
|
|
|
|
|
|
def paddle_imperative_ParameterList(self, num_stacked_param):
|
|
|
|
|
return paddle.imperative.ParameterList(
|
|
|
|
|
[fluid.layers.create_parameter(
|
|
|
|
|
shape=[2, 2], dtype='float32')] * num_stacked_param)
|
|
|
|
|
|
|
|
|
@ -42,12 +54,12 @@ class MyLayer(fluid.Layer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestImperativeContainerParameterList(unittest.TestCase):
|
|
|
|
|
def test_paramter_list(self):
|
|
|
|
|
def paramter_list(self, use_fluid_api):
|
|
|
|
|
data_np = np.random.uniform(-1, 1, [5, 2]).astype('float32')
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
x = fluid.dygraph.to_variable(data_np)
|
|
|
|
|
num_stacked_param = 4
|
|
|
|
|
model = MyLayer(num_stacked_param)
|
|
|
|
|
model = MyLayer(num_stacked_param, use_fluid_api)
|
|
|
|
|
self.assertEqual(len(model.params), num_stacked_param)
|
|
|
|
|
res = model(x)
|
|
|
|
|
self.assertListEqual(res.shape, [5, 2])
|
|
|
|
@ -67,6 +79,10 @@ class TestImperativeContainerParameterList(unittest.TestCase):
|
|
|
|
|
loss = fluid.layers.reduce_mean(res)
|
|
|
|
|
loss.backward()
|
|
|
|
|
|
|
|
|
|
def test_paramter_list(self):
|
|
|
|
|
self.paramter_list(True)
|
|
|
|
|
self.paramter_list(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|