|
|
|
@ -22,13 +22,12 @@ import pickle
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
|
|
|
|
from mindspore.dataset.engine import GeneratorDataset
|
|
|
|
|
from mindspore.dataset import GeneratorDataset
|
|
|
|
|
|
|
|
|
|
import src.constants as rconst
|
|
|
|
|
import src.movielens as movielens
|
|
|
|
|
import src.stat_utils as stat_utils
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DATASET_TO_NUM_USERS_AND_ITEMS = {
|
|
|
|
|
"ml-1m": (6040, 3706),
|
|
|
|
|
"ml-20m": (138493, 26744)
|
|
|
|
@ -205,6 +204,7 @@ class NCFDataset:
|
|
|
|
|
"""
|
|
|
|
|
A dataset for NCF network.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
pos_users,
|
|
|
|
|
pos_items,
|
|
|
|
@ -407,6 +407,7 @@ class RandomSampler:
|
|
|
|
|
"""
|
|
|
|
|
A random sampler for dataset.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, pos_count, num_train_negatives, batch_size):
|
|
|
|
|
self.pos_count = pos_count
|
|
|
|
|
self._num_samples = (1 + num_train_negatives) * self.pos_count
|
|
|
|
@ -433,6 +434,7 @@ class DistributedSamplerOfTrain:
|
|
|
|
|
"""
|
|
|
|
|
A distributed sampler for dataset.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size):
|
|
|
|
|
"""
|
|
|
|
|
Distributed sampler of training dataset.
|
|
|
|
@ -443,15 +445,16 @@ class DistributedSamplerOfTrain:
|
|
|
|
|
self._batch_size = batch_size
|
|
|
|
|
|
|
|
|
|
self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size))
|
|
|
|
|
self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
|
|
|
|
|
self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
|
|
|
|
|
self._total_num_samples = self._samples_per_rank * self._rank_size
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
"""
|
|
|
|
|
Returns the data after each sampling.
|
|
|
|
|
"""
|
|
|
|
|
indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
|
|
|
|
|
indices = indices.tolist()
|
|
|
|
|
indices.extend(indices[:self._total_num_samples-len(indices)])
|
|
|
|
|
indices.extend(indices[:self._total_num_samples - len(indices)])
|
|
|
|
|
indices = indices[self._rank_id:self._total_num_samples:self._rank_size]
|
|
|
|
|
batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)]
|
|
|
|
|
|
|
|
|
@ -463,10 +466,12 @@ class DistributedSamplerOfTrain:
|
|
|
|
|
"""
|
|
|
|
|
return self._batchs_per_rank
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SequenceSampler:
|
|
|
|
|
"""
|
|
|
|
|
A sequence sampler for dataset.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, eval_batch_size, num_users):
|
|
|
|
|
self._eval_users_per_batch = int(
|
|
|
|
|
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
|
|
|
|
@ -491,10 +496,12 @@ class SequenceSampler:
|
|
|
|
|
"""
|
|
|
|
|
return self._eval_batches_per_epoch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedSamplerOfEval:
|
|
|
|
|
"""
|
|
|
|
|
A distributed sampler for eval dataset.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, eval_batch_size, num_users, rank_id, rank_size):
|
|
|
|
|
self._eval_users_per_batch = int(
|
|
|
|
|
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
|
|
|
|
@ -507,8 +514,8 @@ class DistributedSamplerOfEval:
|
|
|
|
|
self._eval_batch_size = eval_batch_size
|
|
|
|
|
|
|
|
|
|
self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size))
|
|
|
|
|
#self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
|
|
|
|
|
#self._total_num_samples = self._samples_per_rank * self._rank_size
|
|
|
|
|
# self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
|
|
|
|
|
# self._total_num_samples = self._samples_per_rank * self._rank_size
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch)
|
|
|
|
@ -525,6 +532,7 @@ class DistributedSamplerOfEval:
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return self._batchs_per_rank
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_eval_batch_size(eval_batch_size):
|
|
|
|
|
"""
|
|
|
|
|
Parse eval batch size.
|
|
|
|
|