@@ -13,59 +13,92 @@
 # limitations under the license.
 
 from __future__ import print_function
 import os
-import six
 import unittest
+import paddle
 import paddle.fluid as fluid
+import six
 from paddle.fluid.framework import IrGraph
 from paddle.fluid import core
 
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 os.environ["CPU_NUM"] = "1"
 
 
-def residual_block(num):
-    def conv_bn_layer(input,
-                      ch_out,
-                      filter_size,
-                      stride,
-                      padding,
-                      act='relu',
-                      bias_attr=False):
-        tmp = fluid.layers.conv2d(
-            input=input,
-            filter_size=filter_size,
-            num_filters=ch_out,
-            stride=stride,
-            padding=padding,
-            act=None,
-            bias_attr=bias_attr)
-        return fluid.layers.batch_norm(input=tmp, act=act)
-
-    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
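+# Small conv net on MNIST-shaped input; returns the feed variables
+# ([img, label]) and the average loss.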
+def conv_block():
+    img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-    hidden = data
-    for _ in six.moves.xrange(num):
-        conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
-        short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
-        hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
-    fc = fluid.layers.fc(input=hidden, size=10)
-    loss = fluid.layers.cross_entropy(input=fc, label=label)
-    loss = fluid.layers.mean(loss)
-    return loss
+    conv_pool_1 = fluid.nets.simple_img_conv_pool(
+        input=img,
+        filter_size=5,
+        num_filters=20,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
+    conv_pool_2 = fluid.nets.simple_img_conv_pool(
+        input=conv_pool_1,
+        filter_size=5,
+        num_filters=50,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+    prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_loss = fluid.layers.mean(loss)
+    return [img, label], avg_loss
 
 
 class TestGraph(unittest.TestCase):
-    def test_graph_functions(self, for_ci=True):
+    def graph_apis(self, use_cuda=False, for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
-            loss = residual_block(2)
+            feeds, loss = conv_block()
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
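+        # Clone the IR graph; the clone must contain the same number of nodes.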
+        backup_graph = graph.clone()
+        self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
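+        # Compile the original and the cloned graph with the same build strategy.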
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.memory_optimize = False
+        build_strategy.enable_inplace = False
+        origin_binary = fluid.CompiledProgram(graph.graph).with_data_parallel(
+            loss_name=loss.name, build_strategy=build_strategy)
+        backup_binary = fluid.CompiledProgram(
+            backup_graph.graph).with_data_parallel(
+                loss_name=loss.name, build_strategy=build_strategy)
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(startup)
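+        # Feed a few MNIST mini-batches through both compiled binaries.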
+        iters = 5
+        batch_size = 8
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.train(), batch_size=batch_size)
+        feeder = fluid.DataFeeder(feed_list=feeds, place=place)
+
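+        # Helper: run `iters` steps on one compiled binary and print the loss.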
+        def train(binary):
+            for _ in range(iters):
+                data = next(train_reader())
+                loss_v = exe.run(binary,
+                                 feed=feeder.feed(data),
+                                 fetch_list=[loss.name])
+                print('{}: {}'.format('loss', loss_v))
+
+        train(origin_binary)
+        train(backup_binary)
+
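+        # Mark every conv2d op so it can be highlighted when the graph is drawn.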
         marked_nodes = set()
         for op in graph.all_op_nodes():
             if op.name().find('conv2d') > -1:
                 marked_nodes.add(op)
         if not for_ci:
             graph.draw('.', 'residual', marked_nodes)
+            backup_marked_nodes = set()
+            for op in backup_graph.all_op_nodes():
+                if op.name().find('conv2d') > -1:
+                    backup_marked_nodes.add(op)
+            backup_graph.draw('.', 'backup', backup_marked_nodes)
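+        # Structural checks: the graph stays acyclic and in one connected piece.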
         self.assertFalse(graph.has_circle())
         self.assertEqual(graph.graph_num(), 1)
         nodes = graph.topology_sort()
@@ -75,14 +108,13 @@ class TestGraph(unittest.TestCase):
         nodes_num = len(graph.all_nodes())
         graph.safe_remove_nodes(marked_nodes)
         self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
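+        # A clone taken after node removal must match the pruned node count.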
+        backup_graph = graph.clone()
+        self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
+        if not for_ci:
+            backup_marked_nodes = set()
+            for op in backup_graph.all_op_nodes():
+                if op.name().find('conv2d') > -1:
+                    backup_marked_nodes.add(op)
+            backup_graph.draw('.', 'backup', backup_marked_nodes)
+
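+    # Entry points: run the graph API checks on CPU, and on CUDA when available.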
+    def test_graph_apis_cpu(self):
+        self.graph_apis(use_cuda=False, for_ci=True)
+
+    def test_graph_apis_cuda(self):
+        if fluid.core.is_compiled_with_cuda():
+            self.graph_apis(use_cuda=True, for_ci=True)
 
 
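+# Standalone usage: `python test_graph.py` runs the CPU case and, when Paddle
+# is compiled with CUDA, the GPU case as well.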
 if __name__ == '__main__':
     unittest.main()