@@ -16,6 +16,7 @@ from __future__ import print_function
 import math
 
 import paddle.fluid as fluid
+from paddle.fluid import compiler
 import paddle.fluid.core as core
 import unittest
 import numpy as np
@@ -58,12 +59,13 @@ class TestFetchAndFeed(unittest.TestCase):
         exe = fluid.Executor(place)
         exe.run(startup)
 
-        pe = fluid.ParallelExecutor(
-            use_cuda=use_cuda, loss_name=loss.name, main_program=main_program)
-        run_parallel_exe(main_program, pe, use_cuda, data, label, loss)
+        train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
+            loss_name=loss.name)
 
-    def run_parallel_exe_with_fetch(self, main, pe, use_cuda, data, label,
-                                    loss):
+        run_parallel_exe(train_cp, exe, use_cuda, data, label, loss)
+
+    def run_parallel_exe_with_fetch(self, compiled_program, exe, use_cuda, data,
+                                    label, loss):
         def get_data(batch_size=8):
             np.random.seed(5)
             while True:
@@ -78,7 +80,7 @@ class TestFetchAndFeed(unittest.TestCase):
         # conv2d_1.b_0@GRAD. Those variables should not be pruned.
         # fluid.memory_optimize(main)
         fetch_list = []
-        all_vars = main.global_block().vars
+        all_vars = compiled_program._program.global_block().vars
 
         for k, v in all_vars.items():
             if ('tmp' not in k) and (
@@ -89,14 +91,18 @@ class TestFetchAndFeed(unittest.TestCase):
         for batch_id, img_label in enumerate(get_data()):
             img, l = img_label
             train_inputs = {data.name: img, label.name: l}
-            ret = pe.run(fetch_list, feed=train_inputs, return_numpy=True)
+            ret = exe.run(compiled_program,
+                          fetch_list=fetch_list,
+                          feed=train_inputs,
+                          return_numpy=True)
             for i in range(len(fetch_list)):
                 assert not math.isnan(np.sum(ret[i])) and \
                        not math.isinf(np.sum(ret[i]))
             if batch_id == 2:
                 break
 
-    def run_parallel_exe_with_feed(self, main, pe, use_cuda, data, label, loss):
+    def run_parallel_exe_with_feed(self, compiled_program, exe, use_cuda, data,
+                                   label, loss):
         def get_data(batch_size=8):
             np.random.seed(5)
             while True:
@@ -114,7 +120,9 @@ class TestFetchAndFeed(unittest.TestCase):
         reader = feeder.decorate_reader(get_data, multi_devices=True)
 
         for batch_id, data in enumerate(reader()):
-            loss_np = pe.run(feed=data, fetch_list=[loss.name])[0]
+            loss_np = exe.run(compiled_program,
+                              feed=data,
+                              fetch_list=[loss.name])[0]
             print(batch_id, loss_np)
             if batch_id == 2:
                 break
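
For reference, a minimal, self-contained sketch of the pattern this patch migrates to: build a program, wrap it with compiler.CompiledProgram(...).with_data_parallel(loss_name=...), and run it through the plain fluid.Executor instead of fluid.ParallelExecutor. The toy fc network, the CPU_NUM setting, and names such as x, y, and loss are illustrative assumptions and are not part of the test above.

import os
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import compiler

os.environ['CPU_NUM'] = '2'  # assumed CPU device count for data parallelism

main_program = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main_program, startup):
    # Toy network standing in for the test's real model: one fc layer.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(y)
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)

# The migrated pattern: CompiledProgram + plain Executor, no ParallelExecutor.
train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
    loss_name=loss.name)

feed = {'x': np.random.random((8, 13)).astype('float32')}  # 8 = 2 devices * 4
loss_np = exe.run(train_cp, feed=feed, fetch_list=[loss.name])[0]
print(loss_np)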