|
|
|
@ -18,6 +18,7 @@ import paddle.v2.attr as attr
|
|
|
|
|
import paddle.v2.data_type as data_type
|
|
|
|
|
import paddle.v2.layer as layer
|
|
|
|
|
import paddle.v2.pooling as pooling
|
|
|
|
|
import paddle.v2.networks as networks
|
|
|
|
|
|
|
|
|
|
pixel = layer.data(name='pixel', type=data_type.dense_vector(128))
|
|
|
|
|
label = layer.data(name='label', type=data_type.integer_value(10))
|
|
|
|
@ -251,5 +252,13 @@ class ProjOpTest(unittest.TestCase):
|
|
|
|
|
print layer.parse_network(conv1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NetworkTests(unittest.TestCase):
|
|
|
|
|
def test_vgg(self):
|
|
|
|
|
img = layer.data(name='pixel', type=data_type.dense_vector(784))
|
|
|
|
|
vgg_out = networks.small_vgg(
|
|
|
|
|
input_image=img, num_channels=1, num_classes=2)
|
|
|
|
|
print layer.parse_network(vgg_out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|