|
|
|
@ -27,6 +27,8 @@ import paddle.dataset.common
|
|
|
|
|
import re
|
|
|
|
|
import random
|
|
|
|
|
import functools
|
|
|
|
|
import six
|
|
|
|
|
import paddle.fluid.compat as cpt
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id',
|
|
|
|
@ -112,6 +114,7 @@ def __initialize_meta_info__():
|
|
|
|
|
categories_set = set()
|
|
|
|
|
with package.open('ml-1m/movies.dat') as movie_file:
|
|
|
|
|
for i, line in enumerate(movie_file):
|
|
|
|
|
line = cpt.to_literal_str(line, encoding='latin')
|
|
|
|
|
movie_id, title, categories = line.strip().split('::')
|
|
|
|
|
categories = categories.split('|')
|
|
|
|
|
for c in categories:
|
|
|
|
@ -136,6 +139,7 @@ def __initialize_meta_info__():
|
|
|
|
|
USER_INFO = dict()
|
|
|
|
|
with package.open('ml-1m/users.dat') as user_file:
|
|
|
|
|
for line in user_file:
|
|
|
|
|
line = cpt.to_literal_str(line, encoding='latin')
|
|
|
|
|
uid, gender, age, job, _ = line.strip().split("::")
|
|
|
|
|
USER_INFO[int(uid)] = UserInfo(
|
|
|
|
|
index=uid, gender=gender, age=age, job_id=job)
|
|
|
|
@ -148,6 +152,7 @@ def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
|
|
|
|
|
with zipfile.ZipFile(file=fn) as package:
|
|
|
|
|
with package.open('ml-1m/ratings.dat') as rating:
|
|
|
|
|
for line in rating:
|
|
|
|
|
line = cpt.to_literal_str(line, encoding='latin')
|
|
|
|
|
if (rand.random() < test_ratio) == is_test:
|
|
|
|
|
uid, mov_id, rating, _ = line.strip().split("::")
|
|
|
|
|
uid = int(uid)
|
|
|
|
@ -187,7 +192,7 @@ def max_movie_id():
|
|
|
|
|
Get the maximum value of movie id.
|
|
|
|
|
"""
|
|
|
|
|
__initialize_meta_info__()
|
|
|
|
|
return reduce(__max_index_info__, list(MOVIE_INFO.values())).index
|
|
|
|
|
return six.moves.reduce(__max_index_info__, list(MOVIE_INFO.values())).index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def max_user_id():
|
|
|
|
@ -195,7 +200,7 @@ def max_user_id():
|
|
|
|
|
Get the maximum value of user id.
|
|
|
|
|
"""
|
|
|
|
|
__initialize_meta_info__()
|
|
|
|
|
return reduce(__max_index_info__, list(USER_INFO.values())).index
|
|
|
|
|
return six.moves.reduce(__max_index_info__, list(USER_INFO.values())).index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __max_job_id_impl__(a, b):
|
|
|
|
@ -210,7 +215,7 @@ def max_job_id():
|
|
|
|
|
Get the maximum value of job id.
|
|
|
|
|
"""
|
|
|
|
|
__initialize_meta_info__()
|
|
|
|
|
return reduce(__max_job_id_impl__, list(USER_INFO.values())).job_id
|
|
|
|
|
return six.moves.reduce(__max_job_id_impl__, list(USER_INFO.values())).job_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def movie_categories():
|
|
|
|
|