@@ -26,14 +26,17 @@ import six
import collections
from itertools import chain

import os
import nltk
from nltk.corpus import movie_reviews
import ssl

# work around SSL certificate errors when fetching the corpus over HTTPS
ssl._create_default_https_context = ssl._create_unverified_context
import zipfile
from functools import cmp_to_key

import paddle.dataset.common

URL = "https://corpora.bj.bcebos.com/movie_reviews%2Fmovie_reviews.zip"
MD5 = '155de2b77c6834dd8eea7cbe88e93acb'

__all__ = ['train', 'test', 'get_word_dict']
NUM_TRAINING_INSTANCES = 1600
NUM_TOTAL_INSTANCES = 2000
@@ -44,6 +47,14 @@ def download_data_if_not_yet():
    """
    Download the data set, if it has not been downloaded yet.
    """
    try:
        # download and extract movie_reviews.zip
        paddle.dataset.common.download(
            URL, 'corpora', md5sum=MD5, save_name='movie_reviews.zip')
        path = os.path.join(paddle.dataset.common.DATA_HOME, 'corpora')
        filename = os.path.join(path, 'movie_reviews.zip')
        zip_file = zipfile.ZipFile(filename)
        zip_file.extractall(path)
        zip_file.close()
        # make sure that nltk can find the data
        if paddle.dataset.common.DATA_HOME not in nltk.data.path:
            nltk.data.path.append(paddle.dataset.common.DATA_HOME)
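For readers who want to see the changed module in action, here is a minimal usage sketch. It assumes the file above is importable as paddle.dataset.sentiment and that train() follows the reader convention used elsewhere in paddle.dataset; neither of those details is part of this diff.

# Hypothetical usage sketch (not part of the diff). Assumes the module path
# paddle.dataset.sentiment and a reader-style interface for train().
import paddle.dataset.sentiment as sentiment

# Fetch and unpack movie_reviews.zip into DATA_HOME/corpora if it is missing.
sentiment.download_data_if_not_yet()

# get_word_dict() and train() are exported via __all__ above.
word_dict = sentiment.get_word_dict()
print("vocabulary size:", len(word_dict))

# Peek at a few samples; each is expected to be (list of word ids, label).
for i, sample in enumerate(sentiment.train()()):
    print(sample)
    if i >= 2:
        break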