|
|
|
@ -81,62 +81,6 @@ class TestAsyncExecutor(unittest.TestCase):
|
|
|
|
|
tarf.extractall(path='./')
|
|
|
|
|
tarf.close()
|
|
|
|
|
|
|
|
|
|
def test_data_feed_desc(self):
|
|
|
|
|
data_feed = fluid.DataFeedDesc('./data.prototxt')
|
|
|
|
|
# assertEqueal(data_feed.proto_desc.batch, 2)
|
|
|
|
|
# assertEqual(len(data_feed.proto_desc.multi_slot_desc), 2)
|
|
|
|
|
self.assertEqual(" ".join(data_feed.desc().split()),
|
|
|
|
|
" ".join(proto_str.split()))
|
|
|
|
|
|
|
|
|
|
def test_run(self):
|
|
|
|
|
# Initialize dataset description
|
|
|
|
|
data_feed = fluid.DataFeedDesc('train_data/data.prototxt')
|
|
|
|
|
data_feed.set_batch_size(
|
|
|
|
|
128) # See API doc for how to change other fields
|
|
|
|
|
|
|
|
|
|
# define network
|
|
|
|
|
# input text data
|
|
|
|
|
data = fluid.layers.data(
|
|
|
|
|
name="words", shape=[1], dtype="int64", lod_level=1)
|
|
|
|
|
# label data
|
|
|
|
|
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
|
|
|
|
|
|
|
|
|
|
avg_cost, acc, prediction = bow_net(data, label)
|
|
|
|
|
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.002)
|
|
|
|
|
opt_ops, weight_and_grad = sgd_optimizer.minimize(avg_cost)
|
|
|
|
|
|
|
|
|
|
# Run startup program
|
|
|
|
|
startup_program = fluid.default_startup_program()
|
|
|
|
|
place = fluid.CPUPlace()
|
|
|
|
|
executor = fluid.Executor(place)
|
|
|
|
|
executor.run(startup_program)
|
|
|
|
|
|
|
|
|
|
main_program = fluid.default_main_program()
|
|
|
|
|
async_executor = fluid.AsyncExecutor(place)
|
|
|
|
|
|
|
|
|
|
self.assertRaises(TypeError, async_executor.run)
|
|
|
|
|
self.assertRaises(TypeError, async_executor.run, main_program)
|
|
|
|
|
self.assertRaises(TypeError, async_executor.run, main_program,
|
|
|
|
|
data_feed)
|
|
|
|
|
|
|
|
|
|
filelist = ['train_data/part-%d' % i for i in range(10)]
|
|
|
|
|
self.assertRaises(TypeError, async_executor.run, main_program,
|
|
|
|
|
data_feed, filelist)
|
|
|
|
|
|
|
|
|
|
thread_num = 4
|
|
|
|
|
self.assertRaises(TypeError, async_executor.run, main_program,
|
|
|
|
|
data_feed, filelist, thread_num)
|
|
|
|
|
|
|
|
|
|
async_executor.run(main_program, data_feed, filelist, thread_num, [acc])
|
|
|
|
|
fluid.io.save_inference_model("imdb.model", [data.name, label.name],
|
|
|
|
|
[acc], executor)
|
|
|
|
|
statinfo = os.stat('imdb.model/__model__')
|
|
|
|
|
self.assertGreater(statinfo.st_size, 0)
|
|
|
|
|
|
|
|
|
|
os.remove('./data.prototxt')
|
|
|
|
|
shutil.rmtree('./train_data')
|
|
|
|
|
shutil.rmtree('./imdb.model')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|