@@ -15,6 +15,7 @@
 import unittest
 import numpy as np
 import random
+import os
 import sys
 
 import paddle
@@ -23,16 +24,17 @@ import paddle.fluid.core as core
 from test_imperative_base import new_program_scope
 from paddle.fluid.imperative.base import to_variable
 
-NUM_USERS = 100
-NUM_ITEMS = 1000
+# Can use Amusic dataset as the DeepCF describes.
+DATA_PATH = os.environ.get('DATA_PATH', '')
 
-BATCH_SIZE = 32
-NUM_BATCHES = 2
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 128))
+NUM_BATCHES = int(os.environ.get('NUM_BATCHES', 5))
+NUM_EPOCHES = int(os.environ.get('NUM_EPOCHES', 1))
 
 
-class MLP(fluid.imperative.Layer):
+class DMF(fluid.imperative.Layer):
     def __init__(self, name_scope):
-        super(MLP, self).__init__(name_scope)
+        super(DMF, self).__init__(name_scope)
         self._user_latent = fluid.imperative.FC(self.full_name(), 256)
         self._item_latent = fluid.imperative.FC(self.full_name(), 256)
 
@@ -61,9 +63,9 @@ class MLP(fluid.imperative.Layer):
         return fluid.layers.elementwise_mul(users, items)
 
 
-class DMF(fluid.imperative.Layer):
+class MLP(fluid.imperative.Layer):
     def __init__(self, name_scope):
-        super(DMF, self).__init__(name_scope)
+        super(MLP, self).__init__(name_scope)
         self._user_latent = fluid.imperative.FC(self.full_name(), 256)
         self._item_latent = fluid.imperative.FC(self.full_name(), 256)
         self._match_layers = []
@@ -87,21 +89,30 @@ class DMF(fluid.imperative.Layer):
 
 
 class DeepCF(fluid.imperative.Layer):
-    def __init__(self, name_scope):
+    def __init__(self, name_scope, num_users, num_items, matrix):
         super(DeepCF, self).__init__(name_scope)
-
-        self._user_emb = fluid.imperative.Embedding(self.full_name(),
-                                                    [NUM_USERS, 256])
-        self._item_emb = fluid.imperative.Embedding(self.full_name(),
-                                                    [NUM_ITEMS, 256])
+        self._num_users = num_users
+        self._num_items = num_items
+        self._rating_matrix = self.create_parameter(
+            fluid.ParamAttr(trainable=False),
+            matrix.shape,
+            matrix.dtype,
+            is_bias=False,
+            default_initializer=fluid.initializer.NumpyArrayInitializer(matrix))
+        self._rating_matrix._stop_gradient = True
 
         self._mlp = MLP(self.full_name())
         self._dmf = DMF(self.full_name())
         self._match_fc = fluid.imperative.FC(self.full_name(), 1, act='sigmoid')
 
     def forward(self, users, items):
-        users_emb = self._user_emb(users)
-        items_emb = self._item_emb(items)
+        # users_emb = self._user_emb(users)
+        # items_emb = self._item_emb(items)
+        users_emb = fluid.layers.gather(self._rating_matrix, users)
+        items_emb = fluid.layers.gather(
+            fluid.layers.transpose(self._rating_matrix, [1, 0]), items)
+        users_emb.stop_gradient = True
+        items_emb.stop_gradient = True
 
         mlp_predictive = self._mlp(users_emb, items_emb)
         dmf_predictive = self._dmf(users_emb, items_emb)
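Note on the hunk above: the trainable Embedding tables are replaced by a frozen copy of the user-item rating matrix, and each id is turned into a representation by gathering the matching row (users) or column of the transpose (items). A minimal numpy sketch of that lookup, with illustrative toy shapes (not part of the patch):

    import numpy as np

    num_users, num_items = 4, 6
    matrix = np.random.rand(num_users, num_items).astype(np.float32)  # frozen ratings
    users = np.array([0, 2])  # a batch of user ids
    items = np.array([1, 5])  # a batch of item ids

    users_emb = matrix[users]    # gather rows -> shape (2, num_items)
    items_emb = matrix.T[items]  # gather columns via transpose -> shape (2, num_users)

Because the matrix is created with trainable=False and both gathered tensors set stop_gradient, no gradient flows back into the ratings; only the MLP/DMF towers and the final FC learn.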
@@ -116,27 +127,79 @@ def get_data():
     user_ids = []
     item_ids = []
     labels = []
+    NUM_USERS = 100
+    NUM_ITEMS = 1000
+    matrix = np.zeros([NUM_USERS, NUM_ITEMS], dtype=np.float32)
+
     for uid in range(NUM_USERS):
         for iid in range(NUM_ITEMS):
-            # 10% positive
-            label = float(random.randint(1, 10) == 1)
+            label = float(random.randint(1, 6) == 1)
             user_ids.append(uid)
             item_ids.append(iid)
             labels.append(label)
-    indices = np.arange(NUM_USERS * NUM_ITEMS)
+            matrix[uid, iid] = label
+    indices = np.arange(len(user_ids))
     np.random.shuffle(indices)
-    users_np = np.array(user_ids, dtype=np.int64)[indices]
-    items_np = np.array(item_ids, dtype=np.int64)[indices]
+    users_np = np.array(user_ids, dtype=np.int32)[indices]
+    items_np = np.array(item_ids, dtype=np.int32)[indices]
     labels_np = np.array(labels, dtype=np.float32)[indices]
     return np.expand_dims(users_np, -1), \
         np.expand_dims(items_np, -1), \
-        np.expand_dims(labels_np, -1)
+        np.expand_dims(labels_np, -1), NUM_USERS, NUM_ITEMS, matrix
+
+
+def load_data(DATA_PATH):
+    sys.stderr.write('loading from %s\n' % DATA_PATH)
+    likes = dict()
+    num_users = -1
+    num_items = -1
+    with open(DATA_PATH, 'r') as f:
+        for l in f.readlines():
+            uid, iid, rating = [int(v) for v in l.split('\t')]
+            num_users = max(num_users, uid + 1)
+            num_items = max(num_items, iid + 1)
+            if float(rating) > 0.0:
+                likes[(uid, iid)] = 1.0
+
+    user_ids = []
+    item_ids = []
+    labels = []
+    matrix = np.zeros([num_users, num_items], dtype=np.float32)
+    for uid, iid in likes.keys():
+        user_ids.append(uid)
+        item_ids.append(iid)
+        labels.append(1.0)
+        matrix[uid, iid] = 1.0
+
+        negative = 0
+        while negative < 3:
+            nuid = random.randint(0, num_users - 1)
+            niid = random.randint(0, num_items - 1)
+            if (nuid, niid) not in likes:
+                negative += 1
+                user_ids.append(nuid)
+                item_ids.append(niid)
+                labels.append(0.0)
+
+    indices = np.arange(len(user_ids))
+    np.random.shuffle(indices)
+    users_np = np.array(user_ids, dtype=np.int32)[indices]
+    items_np = np.array(item_ids, dtype=np.int32)[indices]
+    labels_np = np.array(labels, dtype=np.float32)[indices]
+    return np.expand_dims(users_np, -1), \
+        np.expand_dims(items_np, -1), \
+        np.expand_dims(labels_np, -1), num_users, num_items, matrix
+
+
 class TestImperativeDeepCF(unittest.TestCase):
-    def test_gan_float32(self):
+    def test_deefcf(self):
         seed = 90
-        users_np, items_np, labels_np = get_data()
+        if DATA_PATH:
+            (users_np, items_np, labels_np, num_users, num_items,
+             matrix) = load_data(DATA_PATH)
+        else:
+            (users_np, items_np, labels_np, num_users, num_items,
+             matrix) = get_data()
 
         startup = fluid.Program()
         startup.random_seed = seed
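The new load_data reads tab-separated (uid, iid, rating) triples and, for every observed positive, draws three random unobserved pairs as negatives, giving roughly a 1:3 positive-to-negative ratio. A standalone sketch of that sampling loop on hypothetical toy data (not part of the patch):

    import random

    likes = {(0, 1): 1.0, (2, 3): 1.0}  # toy positive interactions
    num_users, num_items = 4, 6

    samples = []
    for uid, iid in likes.keys():
        samples.append((uid, iid, 1.0))  # keep the observed positive
        negative = 0
        while negative < 3:  # then add three random negatives
            nuid = random.randint(0, num_users - 1)
            niid = random.randint(0, num_items - 1)
            if (nuid, niid) not in likes:
                negative += 1
                samples.append((nuid, niid, 0.0))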
@@ -145,11 +208,11 @@ class TestImperativeDeepCF(unittest.TestCase):
 
         scope = fluid.core.Scope()
         with new_program_scope(main=main, startup=startup, scope=scope):
-            users = fluid.layers.data('users', [1], dtype='int64')
-            items = fluid.layers.data('items', [1], dtype='int64')
+            users = fluid.layers.data('users', [1], dtype='int32')
+            items = fluid.layers.data('items', [1], dtype='int32')
             labels = fluid.layers.data('labels', [1], dtype='float32')
 
-            deepcf = DeepCF('deepcf')
+            deepcf = DeepCF('deepcf', num_users, num_items, matrix)
             prediction = deepcf(users, items)
             loss = fluid.layers.reduce_sum(
                 fluid.layers.log_loss(prediction, labels))
@@ -159,35 +222,44 @@ class TestImperativeDeepCF(unittest.TestCase):
 
             exe = fluid.Executor(fluid.CPUPlace(
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
             exe.run(startup)
-            for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE):
-                static_loss = exe.run(
-                    main,
-                    feed={
-                        users.name: users_np[slice:slice + BATCH_SIZE],
-                        items.name: items_np[slice:slice + BATCH_SIZE],
-                        labels.name: labels_np[slice:slice + BATCH_SIZE]
-                    },
-                    fetch_list=[loss])[0]
-                sys.stderr.write('static loss %s\n' % static_loss)
+            for e in range(NUM_EPOCHES):
+                sys.stderr.write('epoch %d\n' % e)
+                for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE):
+                    if slice + BATCH_SIZE >= users_np.shape[0]:
+                        break
+                    static_loss = exe.run(
+                        main,
+                        feed={
+                            users.name: users_np[slice:slice + BATCH_SIZE],
+                            items.name: items_np[slice:slice + BATCH_SIZE],
+                            labels.name: labels_np[slice:slice + BATCH_SIZE]
+                        },
+                        fetch_list=[loss])[0]
+                    sys.stderr.write('static loss %s\n' % static_loss)
 
         with fluid.imperative.guard():
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-            deepcf = DeepCF('deepcf')
-            for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE):
-                prediction = deepcf(
-                    to_variable(users_np[slice:slice + BATCH_SIZE]),
-                    to_variable(items_np[slice:slice + BATCH_SIZE]))
-                loss = fluid.layers.reduce_sum(
-                    fluid.layers.log_loss(prediction,
-                                          to_variable(labels_np[slice:slice +
-                                                                BATCH_SIZE])))
-                loss._backward()
-                adam = fluid.optimizer.AdamOptimizer(0.01)
-                adam.minimize(loss)
-                deepcf.clear_gradients()
-                dy_loss = loss._numpy()
+            deepcf = DeepCF('deepcf', num_users, num_items, matrix)
+            adam = fluid.optimizer.AdamOptimizer(0.01)
+            for e in range(NUM_EPOCHES):
+                sys.stderr.write('epoch %d\n' % e)
+                for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE):
+                    if slice + BATCH_SIZE >= users_np.shape[0]:
+                        break
+                    prediction = deepcf(
+                        to_variable(users_np[slice:slice + BATCH_SIZE]),
+                        to_variable(items_np[slice:slice + BATCH_SIZE]))
+                    loss = fluid.layers.reduce_sum(
+                        fluid.layers.log_loss(prediction,
+                                              to_variable(labels_np[
+                                                  slice:slice + BATCH_SIZE])))
+                    loss._backward()
+                    adam.minimize(loss)
+                    deepcf.clear_gradients()
+                    dy_loss = loss._numpy()
+                    sys.stderr.write('dynamic loss: %s %s\n' % (slice, dy_loss))
 
         self.assertEqual(static_loss, dy_loss)
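The final assertEqual expects the static-graph and imperative runs, seeded identically and fed identical batches, to land on the same summed log loss. For reference, a numpy sketch of that objective (standard binary cross-entropy summed over the batch; the epsilon here is an assumption for numerical safety, not Paddle's exact default):

    import numpy as np

    def summed_log_loss(prediction, labels, eps=1e-7):
        # sum over the batch of -(y * log(p) + (1 - y) * log(1 - p))
        p = np.clip(prediction, eps, 1.0 - eps)
        return np.sum(-(labels * np.log(p) + (1.0 - labels) * np.log(1.0 - p)))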