|
|
@ -29,6 +29,12 @@ def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2):
|
|
|
|
print('Skip use_cuda=True because Paddle is not compiled with cuda')
|
|
|
|
print('Skip use_cuda=True because Paddle is not compiled with cuda')
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if use_parallel_executor and os.name == 'nt':
|
|
|
|
|
|
|
|
print(
|
|
|
|
|
|
|
|
'Skip use_parallel_executor=True because Paddle comes without parallel support on windows'
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
word_dict = paddle.dataset.imdb.word_dict()
|
|
|
|
word_dict = paddle.dataset.imdb.word_dict()
|
|
|
|
train_reader = paddle.batch(
|
|
|
|
train_reader = paddle.batch(
|
|
|
|
paddle.dataset.imdb.train(word_dict), batch_size=batch_size)
|
|
|
|
paddle.dataset.imdb.train(word_dict), batch_size=batch_size)
|
|
|
|