@@ -17,8 +17,6 @@ import os
 import pickle
 import collections
 import argparse
-import urllib.request
-import tarfile
 
 import numpy as np
 from mindspore.mindrecord import FileWriter
@@ -140,7 +138,7 @@ def mkdir_path(file_path):
         os.makedirs(file_path)
 
 
-def statsdata(file_path, dict_output_path, criteo_stats_dict, dense_dim=13, slot_dim=26):
+def statsdata(file_path, dict_output_path, recommendation_dataset_stats_dict, dense_dim=13, slot_dim=26):
     """Preprocess data and save data"""
     with open(file_path, encoding="utf-8") as file_in:
         errorline_list = []
@@ -161,13 +159,13 @@ def statsdata(file_path, dict_output_path, criteo_stats_dict, dense_dim=13, slot
 
             assert len(values) == dense_dim, "values.size: {}".format(len(values))
             assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
-            criteo_stats_dict.stats_vals(values)
-            criteo_stats_dict.stats_cats(cats)
-    criteo_stats_dict.save_dict(dict_output_path)
+            recommendation_dataset_stats_dict.stats_vals(values)
+            recommendation_dataset_stats_dict.stats_cats(cats)
+    recommendation_dataset_stats_dict.save_dict(dict_output_path)
 
 
-def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stats_dict, part_rows=2000000,
-                                  line_per_sample=1000, train_line_count=None,
+def random_split_trans2mindrecord(input_file_path, output_file_path, recommendation_dataset_stats_dict,
+                                  part_rows=2000000, line_per_sample=1000, train_line_count=None,
                                   test_size=0.1, seed=2020, dense_dim=13, slot_dim=26):
     """Random split data and save mindrecord"""
     if train_line_count is None:
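
For context on the rename: statsdata() streams a tab-separated file (one label, then dense_dim numeric fields, then slot_dim categorical fields per line) and accumulates statistics into the stats-dict object it is handed. A minimal driving sketch, assuming this module's StatsDict/statsdata/mkdir_path and a train.txt path that this diff does not fix:

    # Hypothetical paths; only the StatsDict signature is taken from the __main__ block below.
    stats = StatsDict(field_size=13 + 26, dense_dim=13, slot_dim=26, skip_id_convert=False)
    mkdir_path("./recommendation_dataset/stats_dict/")
    statsdata("./recommendation_dataset/origin_data/train.txt",
              "./recommendation_dataset/stats_dict/",
              stats, dense_dim=13, slot_dim=26)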
@@ -216,7 +214,7 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stat
 
             assert len(values) == dense_dim, "values.size: {}".format(len(values))
             assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
-            ids, wts = criteo_stats_dict.map_cat2id(values, cats)
+            ids, wts = recommendation_dataset_stats_dict.map_cat2id(values, cats)
             ids_list.extend(ids)
             wts_list.extend(wts)
 
@@ -261,10 +259,8 @@ if __name__ == '__main__':
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="criteo data")
-    parser.add_argument("--data_type", type=str, default='criteo', choices=['criteo', 'synthetic'],
-                        help='Currently we support criteo dataset and synthetic dataset')
-    parser.add_argument("--data_path", type=str, default="./criteo_data/", help='The path of the data file')
+    parser = argparse.ArgumentParser(description="Recommendation dataset")
+    parser.add_argument("--data_path", type=str, default="./recommendation_dataset/", help='The path of the data file')
     parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continuous fields')
     parser.add_argument("--slot_dim", type=int, default=26,
                         help='The number of your sparse fields, which can also be called category features.')
@@ -277,19 +273,6 @@ if __name__ == '__main__':
     args, _ = parser.parse_known_args()
     data_path = args.data_path
 
-    if args.data_type == 'criteo':
-        download_data_path = data_path + "origin_data/"
-        mkdir_path(download_data_path)
-
-        url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
-        file_name = download_data_path + '/' + url.split('/')[-1]
-        urllib.request.urlretrieve(url, filename=file_name)
-
-        tar = tarfile.open(file_name)
-        names = tar.getnames()
-        for name in names:
-            tar.extract(name, path=download_data_path)
-        tar.close()
     target_field_size = args.dense_dim + args.slot_dim
     stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
                       skip_id_convert=args.skip_id_convert)
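
With the automatic download gone, the raw data must already sit under --data_path before the script runs. The deleted block can still be reproduced manually when the Criteo set is wanted; a sketch reusing the URL removed above (the destination directory is an assumption chosen to match the new --data_path default):

    # Manual stand-in for the deleted download block.
    import os
    import tarfile
    import urllib.request

    URL = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
    DEST = "./recommendation_dataset/origin_data/"

    os.makedirs(DEST, exist_ok=True)
    archive = os.path.join(DEST, URL.split('/')[-1])
    urllib.request.urlretrieve(URL, filename=archive)
    with tarfile.open(archive) as tar:
        tar.extractall(path=DEST)  # extracts the dataset files next to the archive

After extraction the script is invoked as usual, e.g. python preprocess_data.py --data_path=./recommendation_dataset/ (the script name is assumed, not fixed by this diff).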