parent
0ec5a57072
commit
49f8076137
@ -0,0 +1,187 @@
|
||||
<!--TOC -->
|
||||
|
||||
- [Bayesian Graph Collaborative Filtering](#bayesian-graph-collaborative-filtering)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Features](#features)
|
||||
- [Mixed Precision](#mixed-precision)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Description of random situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
<!--TOC -->
|
||||
# [Bayesian Graph Collaborative Filtering](#contents)
|
||||
|
||||
Bayesian Graph Collaborative Filtering(BGCF) was proposed in 2020 by Sun J, Guo W, Zhang D et al. By naturally incorporating the
|
||||
uncertainty in the user-item interaction graph shows excellent performance on Amazon recommendation dataset.This is an example of
|
||||
training of BGCF with Amazon-Beauty dataset in MindSpore. More importantly, this is the first open source version for BGCF.
|
||||
|
||||
[Paper](https://dl.acm.org/doi/pdf/10.1145/3394486.3403254): Sun J, Guo W, Zhang D, et al. A Framework for Recommending Accurate and Diverse Items Using Bayesian Graph Convolutional Neural Networks[C]//Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2020: 2030-2039.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
Specially, BGCF contains two main modules. The first is sampling, which produce sample graphs based in node copying. Another module
|
||||
aggregate the neighbors sampling from nodes consisting of mean aggregator and attention aggregator.
|
||||
|
||||
# [Dataset](#contents)
|
||||
- Dataset size:
|
||||
Statistics of dataset used are summarized as below:
|
||||
|
||||
| | Amazon-Beauty |
|
||||
| ------------------ | -----------------------:|
|
||||
| Task | Recommendation |
|
||||
| # User | 7068 (1 graph) |
|
||||
| # Item | 3570 |
|
||||
| # Interaction | 79506 |
|
||||
| # Training Data | 60818 |
|
||||
| # Test Data | 18688 |
|
||||
| # Density | 0.315% |
|
||||
|
||||
- Data Preparation
|
||||
- Place the dataset to any path you want, the folder should include files as follows(we use Amazon-Beauty dataset as an example)"
|
||||
```
|
||||
.
|
||||
└─data
|
||||
├─ratings_Beauty.csv
|
||||
```
|
||||
|
||||
- Generate dataset in mindrecord format for Amazon-Beauty.
|
||||
```builddoutcfg
|
||||
cd ./scripts
|
||||
# SRC_PATH is the dataset file path you download.
|
||||
sh run_process_data_ascend.sh [SRC_PATH]
|
||||
```
|
||||
|
||||
- Launch
|
||||
```
|
||||
# Generate dataset in mindrecord format for Amazon-Beauty.
|
||||
sh ./run_process_data_ascend.sh ./data
|
||||
|
||||
# [Features](#contents)
|
||||
|
||||
## Mixed Precision
|
||||
|
||||
To ultilize the strong computation power of Ascend chip, and accelerate the training process, the mixed training method is used. MindSpore is able to cope with FP32 inputs and FP16 operators. In BGCF example, the model is set to FP16 mode except for the loss calculation part.
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardward (Ascend)
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||
|
||||
# [Quick Start](#contents)
|
||||
|
||||
After installing MindSpore via the official website and Dataset is correctly generated, you can start training and evaluation as follows.
|
||||
|
||||
- running on Ascend
|
||||
|
||||
```
|
||||
# run training example with Amazon-Beauty dataset
|
||||
sh run_train_ascend.sh
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```shell
|
||||
.
|
||||
└─bgcf
|
||||
├─README.md
|
||||
├─scripts
|
||||
| ├─run_process_data_ascend.sh # Generate dataset in mindrecord format
|
||||
| └─run_train_ascend.sh # Launch training
|
||||
|
|
||||
├─src
|
||||
| ├─bgcf.py # BGCF model
|
||||
| ├─callback.py # Callback function
|
||||
| ├─config.py # Training configurations
|
||||
| ├─dataset.py # Data preprocessing
|
||||
| ├─metrics.py # Recommendation metrics
|
||||
| └─utils.py # Utils for training bgcf
|
||||
|
|
||||
└─train.py # Train net
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py.
|
||||
|
||||
- config for BGCF dataset
|
||||
|
||||
```python
|
||||
"learning_rate": 0.001, # Learning rate
|
||||
"num_epochs": 600, # Epoch sizes for training
|
||||
"num_neg": 10, # Negative sampling rate
|
||||
"raw_neighs": 40, # Num of sampling neighbors in raw graph
|
||||
"gnew_neighs": 20, # Num of sampling neighbors in sample graph
|
||||
"input_dim": 64, # User and item embedding dimension
|
||||
"l2_coeff": 0.03 # l2 coefficient
|
||||
"neighbor_dropout": [0.0, 0.2, 0.3]# Dropout ratio for different aggregation layer
|
||||
"num_graphs":5 # Num of sample graph
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### Training
|
||||
|
||||
- running on Ascend
|
||||
```python
|
||||
sh run_train_ascend.sh
|
||||
```
|
||||
|
||||
Training result will be stored in the scripts path, whose folder name begins with "train". You can find the result like the
|
||||
followings in log.
|
||||
|
||||
```python
|
||||
Epoch 001 iter 12 loss 34696.242
|
||||
Epoch 002 iter 12 loss 34275.508
|
||||
Epoch 003 iter 12 loss 30620.635
|
||||
Epoch 004 iter 12 loss 21628.908
|
||||
|
||||
...
|
||||
Epoch 597 iter 12 loss 3662.3152
|
||||
Epoch 598 iter 12 loss 3640.7612
|
||||
Epoch 599 iter 12 loss 3654.9087
|
||||
Epoch 600 iter 12 loss 3632.4585
|
||||
epoch:600, recall_@10:0.10393, recall_@20:0.15669, ndcg_@10:0.07564, ndcg_@20:0.09343,
|
||||
sedp_@10:0.01936, sedp_@20:0.01544, nov_@10:7.58599, nov_@20:7.79782
|
||||
...
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
## [Performance](#contents)
|
||||
|
||||
| Parameter | BGCF |
|
||||
| ------------------------------------ | ----------------------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| uploaded Date | 09/04/2020(month/day/year) |
|
||||
| MindSpore Version | 1.0 |
|
||||
| Dataset | Amazon-Beauty |
|
||||
| Training Parameter | epoch=600 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | BPR loss |
|
||||
| Recall@20 | 0.1534 |
|
||||
| NDCG@20 | 0.0912 |
|
||||
| Total time | 30min |
|
||||
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf |
|
||||
|
||||
# [Description of random situation](#contents)
|
||||
|
||||
BGCF model contains lots of dropout operations, if you want to disable dropout, set the neighbor_dropout to [0.0, 0.0, 0.0] in src/config.py.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](http://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
||||
|
@ -0,0 +1,76 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: sh run_process_data_ascend.sh [SRC_PATH] "
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
SRC_PATH=$(get_real_path $1)
|
||||
echo $SRC_PATH
|
||||
|
||||
|
||||
if [ ! -d data_mr ]; then
|
||||
mkdir data_mr
|
||||
else
|
||||
echo data_mr exist
|
||||
fi
|
||||
MINDRECORD_PATH=`pwd`/data_mr
|
||||
|
||||
rm -rf ${MINDRECORD_PATH:?}/*
|
||||
INTER_FILE_DIR=$MINDRECORD_PATH/InterFile
|
||||
mkdir -p $INTER_FILE_DIR
|
||||
|
||||
cd ../../../../utils/graph_to_mindrecord || exit
|
||||
|
||||
echo "Start to converting data."
|
||||
python amazon_beauty/converting_data.py --src_path $SRC_PATH --out_path $INTER_FILE_DIR
|
||||
|
||||
echo "Start to generate train_mr."
|
||||
python writer.py --mindrecord_script amazon_beauty \
|
||||
--mindrecord_file "$MINDRECORD_PATH/train_mr" \
|
||||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "$INTER_FILE_DIR/user.csv:$INTER_FILE_DIR/item.csv:$INTER_FILE_DIR/rating_train.csv"
|
||||
|
||||
echo "Start to generate test_mr."
|
||||
python writer.py --mindrecord_script amazon_beauty \
|
||||
--mindrecord_file "$MINDRECORD_PATH/test_mr" \
|
||||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "$INTER_FILE_DIR/user.csv:$INTER_FILE_DIR/item.csv:$INTER_FILE_DIR/rating_test.csv"
|
||||
|
||||
for id in {0..4}
|
||||
do
|
||||
echo "Start to generate sampled${id}_mr."
|
||||
python writer.py --mindrecord_script amazon_beauty \
|
||||
--mindrecord_file "${MINDRECORD_PATH}/sampled${id}_mr" \
|
||||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "$INTER_FILE_DIR/user.csv:$INTER_FILE_DIR/item.csv:$INTER_FILE_DIR/rating_sampled${id}.csv"
|
||||
done
|
||||
|
@ -0,0 +1,38 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
mkdir ./ckpts
|
||||
env > env.log
|
||||
echo "start training for device $DEVICE_ID"
|
||||
|
||||
python train.py --datapath=../data_mr &> log &
|
||||
|
||||
cd ..
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,57 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py
|
||||
"""
|
||||
import argparse
|
||||
|
||||
|
||||
def parser_args():
|
||||
"""Config for BGCF"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--dataset", type=str, default="Beauty")
|
||||
parser.add_argument("-dpath", "--datapath", type=str, default="./scripts/data_mr")
|
||||
parser.add_argument("-de", "--device", type=str, default='0')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--Ks', type=list, default=[5, 10, 20, 100])
|
||||
parser.add_argument('--test_ratio', type=float, default=0.2)
|
||||
parser.add_argument('--val_ratio', type=float, default=None)
|
||||
parser.add_argument('-w', '--workers', type=int, default=10)
|
||||
|
||||
parser.add_argument("-eps", "--epsilon", type=float, default=1e-8)
|
||||
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3)
|
||||
parser.add_argument("-l2", "--l2", type=float, default=0.03)
|
||||
parser.add_argument("-wd", "--weight_decay", type=float, default=0.01)
|
||||
parser.add_argument("-act", "--activation", type=str, default='tanh', choices=['relu', 'tanh'])
|
||||
parser.add_argument("-ndrop", "--neighbor_dropout", type=list, default=[0.0, 0.2, 0.3])
|
||||
parser.add_argument("-log", "--log_name", type=str, default='test')
|
||||
|
||||
parser.add_argument("-e", "--num_epoch", type=int, default=600)
|
||||
parser.add_argument('-input', '--input_dim', type=int, default=64, choices=[64, 128])
|
||||
parser.add_argument("-b", "--batch_pairs", type=int, default=5000)
|
||||
parser.add_argument('--eval_interval', type=int, default=20)
|
||||
|
||||
parser.add_argument("-neg", "--num_neg", type=int, default=10)
|
||||
parser.add_argument('-max', '--max_degree', type=str, default='[128,128]')
|
||||
parser.add_argument("-g1", "--raw_neighs", type=int, default=40)
|
||||
parser.add_argument("-g2", "--gnew_neighs", type=int, default=20)
|
||||
parser.add_argument("-emb", "--embedded_dimension", type=int, default=64)
|
||||
parser.add_argument('-dist', '--distance', type=str, default='iou')
|
||||
parser.add_argument('--dist_reg', type=float, default=0.003)
|
||||
|
||||
parser.add_argument('-ng', '--num_graphs', type=int, default=5)
|
||||
parser.add_argument('-geps', '--graph_epsilon', type=float, default=0.01)
|
||||
|
||||
return parser.parse_args()
|
@ -0,0 +1,191 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
preprocess raw data; generate batched data and sample neighbors on graph for training and test;
|
||||
Amazon Beauty datasets are supported by our example, the original versions of these datasets are as follows:
|
||||
@article{Amazon Beauty,
|
||||
title = {Ups and Downs: Modeling the Visual Evolution of Fashion Trends with One-Class Collaborative Filtering},
|
||||
author = {R. He, J. McAuley},
|
||||
journal = {WWW},
|
||||
year = {2016},
|
||||
url = {http://jmcauley.ucsd.edu/data/amazon}
|
||||
}
|
||||
"""
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
class RandomBatchedSampler(ds.Sampler):
|
||||
"""RandomBatchedSampler generate random sequence without replacement in a batched manner"""
|
||||
|
||||
sampled_graph_index = 0
|
||||
|
||||
def __init__(self, index_range, num_edges_per_sample):
|
||||
super().__init__()
|
||||
self.index_range = index_range
|
||||
self.num_edges_per_sample = num_edges_per_sample
|
||||
|
||||
def __iter__(self):
|
||||
self.sampled_graph_index += 1
|
||||
indices = [i for i in range(self.index_range)]
|
||||
np.random.shuffle(indices)
|
||||
for i in range(0, self.index_range, self.num_edges_per_sample):
|
||||
if i + self.num_edges_per_sample <= self.index_range:
|
||||
result = indices[i: i + self.num_edges_per_sample]
|
||||
result.append(self.sampled_graph_index)
|
||||
yield result
|
||||
|
||||
|
||||
class TrainGraphDataset():
|
||||
"""Sample node neighbors on graphs for training"""
|
||||
|
||||
def __init__(self, train_graph, sampled_graphs, batch_num, num_samples, num_bgcn_neigh, num_neg):
|
||||
self.g = train_graph
|
||||
self.batch_num = batch_num
|
||||
self.sampled_graphs = sampled_graphs
|
||||
self.sampled_graph_num = len(sampled_graphs)
|
||||
self.num_samples = num_samples
|
||||
self.num_bgcn_neigh = num_bgcn_neigh
|
||||
self.num_neg = num_neg
|
||||
|
||||
def __len__(self):
|
||||
return self.g.graph_info()['edge_num'][0] // self.batch_num
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Sample negative items with their neighbors, user neighbors, pos item neighbors
|
||||
based on the user-item pairs
|
||||
"""
|
||||
sampled_graph_index = index[-1] % self.sampled_graph_num
|
||||
index = index[0:-1]
|
||||
train_graph = self.g
|
||||
sampled_graph = self.sampled_graphs[sampled_graph_index]
|
||||
|
||||
rating = train_graph.get_nodes_from_edges(index.astype(np.int32))
|
||||
users = rating[:, 0]
|
||||
|
||||
u_group_nodes = train_graph.get_sampled_neighbors(
|
||||
node_list=users, neighbor_nums=[1], neighbor_types=[0])
|
||||
pos_users = u_group_nodes[:, 1]
|
||||
u_group_nodes = np.concatenate((users, pos_users), axis=0)
|
||||
u_group_nodes = u_group_nodes.reshape(-1,).tolist()
|
||||
u_neighs = train_graph.get_sampled_neighbors(
|
||||
node_list=u_group_nodes, neighbor_nums=[self.num_samples], neighbor_types=[1])
|
||||
u_neighs = u_neighs[:, 1:]
|
||||
u_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||
node_list=u_group_nodes, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[1])
|
||||
u_gnew_neighs = u_gnew_neighs[:, 1:]
|
||||
|
||||
items = rating[:, 1]
|
||||
i_group_nodes = train_graph.get_sampled_neighbors(
|
||||
node_list=items, neighbor_nums=[1], neighbor_types=[1])
|
||||
pos_items = i_group_nodes[:, 1]
|
||||
i_group_nodes = np.concatenate((items, pos_items), axis=0)
|
||||
i_group_nodes = i_group_nodes.reshape(-1,).tolist()
|
||||
i_neighs = train_graph.get_sampled_neighbors(
|
||||
node_list=i_group_nodes, neighbor_nums=[self.num_samples], neighbor_types=[0])
|
||||
i_neighs = i_neighs[:, 1:]
|
||||
i_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||
node_list=i_group_nodes, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[0])
|
||||
i_gnew_neighs = i_gnew_neighs[:, 1:]
|
||||
|
||||
neg_item_id = train_graph.get_neg_sampled_neighbors(
|
||||
node_list=users, neg_neighbor_num=self.num_neg, neg_neighbor_type=1)
|
||||
neg_item_id = neg_item_id[:, 1:]
|
||||
neg_group_nodes = neg_item_id.reshape(-1,)
|
||||
neg_neighs = train_graph.get_sampled_neighbors(
|
||||
node_list=neg_group_nodes, neighbor_nums=[self.num_samples], neighbor_types=[0])
|
||||
neg_neighs = neg_neighs[:, 1:]
|
||||
neg_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||
node_list=neg_group_nodes, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[0])
|
||||
neg_gnew_neighs = neg_gnew_neighs[:, 1:]
|
||||
|
||||
return users, items, neg_item_id, pos_users, pos_items, u_group_nodes, u_neighs, u_gnew_neighs, \
|
||||
i_group_nodes, i_neighs, i_gnew_neighs, neg_group_nodes, neg_neighs, neg_gnew_neighs
|
||||
|
||||
|
||||
class TestGraphDataset():
|
||||
"""Sample node neighbors on graphs for test"""
|
||||
|
||||
def __init__(self, g, sampled_graphs, num_samples, num_bgcn_neigh, num_neg):
|
||||
self.g = g
|
||||
self.sampled_graphs = sampled_graphs
|
||||
self.sampled_graph_index = 0
|
||||
self.num_samples = num_samples
|
||||
self.num_bgcn_neigh = num_bgcn_neigh
|
||||
self.num_neg = num_neg
|
||||
self.num_user = self.g.graph_info()["node_num"][0]
|
||||
self.num_item = self.g.graph_info()["node_num"][1]
|
||||
|
||||
def random_select_sampled_graph(self):
|
||||
self.sampled_graph_index = np.random.randint(len(self.sampled_graphs))
|
||||
|
||||
def get_user_sapmled_neighbor(self):
|
||||
"""Sample all users neighbors for test"""
|
||||
users = np.arange(self.num_user, dtype=np.int32)
|
||||
u_neighs = self.g.get_sampled_neighbors(
|
||||
node_list=users, neighbor_nums=[self.num_samples], neighbor_types=[1])
|
||||
u_neighs = u_neighs[:, 1:]
|
||||
sampled_graph = self.sampled_graphs[self.sampled_graph_index]
|
||||
u_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||
node_list=users, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[1])
|
||||
u_gnew_neighs = u_gnew_neighs[:, 1:]
|
||||
return u_neighs, u_gnew_neighs
|
||||
|
||||
def get_item_sampled_neighbor(self):
|
||||
"""Sample all items neighbors for test"""
|
||||
items = np.arange(self.num_user, self.num_user + self.num_item, dtype=np.int32)
|
||||
i_neighs = self.g.get_sampled_neighbors(
|
||||
node_list=items, neighbor_nums=[self.num_samples], neighbor_types=[0])
|
||||
i_neighs = i_neighs[:, 1:]
|
||||
|
||||
sampled_graph = self.sampled_graphs[self.sampled_graph_index]
|
||||
i_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||
node_list=items, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[0])
|
||||
i_gnew_neighs = i_gnew_neighs[:, 1:]
|
||||
return i_neighs, i_gnew_neighs
|
||||
|
||||
|
||||
def load_graph(data_path):
|
||||
"""Load train graph, test graph and sampled graph"""
|
||||
train_graph = ds.GraphData(
|
||||
data_path + "/train_mr", num_parallel_workers=8)
|
||||
|
||||
test_graph = ds.GraphData(
|
||||
data_path + "/test_mr", num_parallel_workers=8)
|
||||
|
||||
sampled_graph_list = []
|
||||
for i in range(0, 5):
|
||||
sampled_graph = ds.GraphData(
|
||||
data_path + "/sampled" + str(i) + "_mr", num_parallel_workers=8)
|
||||
sampled_graph_list.append(sampled_graph)
|
||||
|
||||
return train_graph, test_graph, sampled_graph_list
|
||||
|
||||
|
||||
def create_dataset(train_graph, sampled_graph_list, batch_size=32, repeat_size=1, num_samples=40, num_bgcn_neigh=20,
|
||||
num_neg=10):
|
||||
"""Data generator for training"""
|
||||
edge_num = train_graph.graph_info()['edge_num'][0]
|
||||
out_column_names = ["users", "items", "neg_item_id", "pos_users", "pos_items", "u_group_nodes", "u_neighs",
|
||||
"u_gnew_neighs", "i_group_nodes", "i_neighs", "i_gnew_neighs", "neg_group_nodes",
|
||||
"neg_neighs", "neg_gnew_neighs"]
|
||||
train_graph_dataset = TrainGraphDataset(
|
||||
train_graph, sampled_graph_list, batch_size, num_samples, num_bgcn_neigh, num_neg)
|
||||
dataset = ds.GeneratorDataset(source=train_graph_dataset, column_names=out_column_names,
|
||||
sampler=RandomBatchedSampler(edge_num, batch_size), num_parallel_workers=8)
|
||||
dataset = dataset.repeat(repeat_size)
|
||||
|
||||
return dataset
|
@ -0,0 +1,184 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Recommendation metrics
|
||||
"""
|
||||
import math
|
||||
import heapq
|
||||
from multiprocessing import Pool
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.utils import convert_item_id
|
||||
|
||||
|
||||
def ndcg_k(actual, predicted, topk):
|
||||
"""Calculates the normalized discounted cumulative gain at k"""
|
||||
idcg = idcg_k(actual, topk)
|
||||
res = 0
|
||||
|
||||
dcg_k = sum([int(predicted[j] in set(actual)) / math.log(j + 2, 2) for j in range(topk)])
|
||||
res += dcg_k / idcg
|
||||
return res
|
||||
|
||||
|
||||
def idcg_k(actual, k):
|
||||
"""Calculates the ideal discounted cumulative gain at k"""
|
||||
res = sum([1.0 / math.log(i + 2, 2) for i in range(min(k, len(actual)))])
|
||||
return 1.0 if not res else res
|
||||
|
||||
|
||||
def recall_at_k_2(r, k, all_pos_num):
|
||||
"""Calculates the recall at k"""
|
||||
r = np.asfarray(r)[:k]
|
||||
return np.sum(r) / all_pos_num
|
||||
|
||||
|
||||
def novelty_at_k(topk_items, item_degree_dict, num_user, k):
|
||||
"""Calculate the novelty at k"""
|
||||
avg_nov = []
|
||||
for item in topk_items[:k]:
|
||||
avg_nov.append(-np.log2(item_degree_dict[item] / num_user))
|
||||
return np.mean(avg_nov)
|
||||
|
||||
|
||||
def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
|
||||
"""Return the n largest score from the item_score by heap algorithm"""
|
||||
item_score = {}
|
||||
for i in test_items:
|
||||
item_score[i] = rating[i]
|
||||
|
||||
K_max = max(Ks)
|
||||
K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)
|
||||
|
||||
r = []
|
||||
for i in K_max_item_score:
|
||||
if i in user_pos_test:
|
||||
r.append(1)
|
||||
else:
|
||||
r.append(0)
|
||||
return r, K_max_item_score
|
||||
|
||||
|
||||
def get_performance(user_pos_test, r, K_max_item, item_degree_dict, num_user, Ks):
|
||||
"""Wraps the model metrics"""
|
||||
recall, ndcg, novelty = [], [], []
|
||||
for K in Ks:
|
||||
recall.append(recall_at_k_2(r, K, len(user_pos_test)))
|
||||
ndcg.append(ndcg_k(user_pos_test, K_max_item, K))
|
||||
novelty.append(novelty_at_k(K_max_item, item_degree_dict, num_user, K))
|
||||
return {'recall': np.array(recall), 'ndcg': np.array(ndcg), 'nov': np.array(novelty)}
|
||||
|
||||
|
||||
class BGCFEvaluate:
|
||||
"""
|
||||
Evaluate the model recommendation performance
|
||||
"""
|
||||
|
||||
def __init__(self, parser, train_graph, test_graph, Ks):
|
||||
self.num_user = train_graph.graph_info()["node_num"][0]
|
||||
self.num_item = train_graph.graph_info()["node_num"][1]
|
||||
self.Ks = Ks
|
||||
|
||||
self.test_set = []
|
||||
self.train_set = []
|
||||
for i in range(0, self.num_user):
|
||||
train_item = train_graph.get_all_neighbors(node_list=[i], neighbor_type=1)
|
||||
train_item = train_item[1:]
|
||||
self.train_set.append(train_item)
|
||||
for i in range(0, self.num_user):
|
||||
test_item = test_graph.get_all_neighbors(node_list=[i], neighbor_type=1)
|
||||
test_item = test_item[1:]
|
||||
self.test_set.append(test_item)
|
||||
self.train_set = convert_item_id(self.train_set, self.num_user).tolist()
|
||||
self.test_set = convert_item_id(self.test_set, self.num_user).tolist()
|
||||
|
||||
self.item_deg_dict = {}
|
||||
self.item_full_set = []
|
||||
for i in range(self.num_user, self.num_user + self.num_item):
|
||||
train_users = train_graph.get_all_neighbors(node_list=[i], neighbor_type=0)
|
||||
train_users = train_users.tolist()
|
||||
if isinstance(train_users, int):
|
||||
train_users = []
|
||||
else:
|
||||
train_users = train_users[1:]
|
||||
self.item_deg_dict[i - self.num_user] = len(train_users)
|
||||
test_users = test_graph.get_all_neighbors(node_list=[i], neighbor_type=0)
|
||||
test_users = test_users.tolist()
|
||||
if isinstance(test_users, int):
|
||||
test_users = []
|
||||
else:
|
||||
test_users = test_users[1:]
|
||||
self.item_full_set.append(train_users + test_users)
|
||||
|
||||
def test_one_user(self, x):
|
||||
"""Calculate one user metrics"""
|
||||
rating = x[0]
|
||||
u = x[1]
|
||||
|
||||
training_items = self.train_set[u]
|
||||
|
||||
user_pos_test = self.test_set[u]
|
||||
|
||||
all_items = set(range(self.num_item))
|
||||
|
||||
test_items = list(all_items - set(training_items))
|
||||
|
||||
r, k_max_items = ranklist_by_heapq(user_pos_test, test_items, rating, self.Ks)
|
||||
|
||||
return get_performance(user_pos_test, r, k_max_items, self.item_deg_dict, self.num_user, self.Ks), \
|
||||
[k_max_items[:self.Ks[x]] for x in range(len(self.Ks))]
|
||||
|
||||
def eval_with_rep(self, user_rep, item_rep, parser):
|
||||
"""Evaluation with user and item rep"""
|
||||
result = {'recall': np.zeros(len(self.Ks)), 'ndcg': np.zeros(len(self.Ks)),
|
||||
'nov': np.zeros(len(self.Ks))}
|
||||
pool = Pool(parser.workers)
|
||||
user_indexes = np.arange(self.num_user)
|
||||
|
||||
rating_preds = user_rep @ item_rep.transpose()
|
||||
user_rating_uid = zip(rating_preds, user_indexes)
|
||||
all_result = pool.map(self.test_one_user, user_rating_uid)
|
||||
|
||||
top20 = []
|
||||
|
||||
for re in all_result:
|
||||
result['recall'] += re[0]['recall'] / self.num_user
|
||||
result['ndcg'] += re[0]['ndcg'] / self.num_user
|
||||
result['nov'] += re[0]['nov'] / self.num_user
|
||||
top20.append(re[1][2])
|
||||
|
||||
pool.close()
|
||||
|
||||
sedp = [[] for i in range(len(self.Ks) - 1)]
|
||||
|
||||
num_all_links = np.sum([len(x) for x in self.item_full_set])
|
||||
|
||||
for k in range(len(self.Ks) - 1):
|
||||
for u in range(self.num_user):
|
||||
diff = []
|
||||
pred_items_at_k = all_result[u][1][k]
|
||||
for item in pred_items_at_k:
|
||||
if item in self.test_set[u]:
|
||||
avg_prob_all_user = len(self.item_full_set[item]) / num_all_links
|
||||
diff.append(max((self.Ks[k] - pred_items_at_k.index(item) - 1)
|
||||
/ (self.Ks[k] - 1) - avg_prob_all_user, 0))
|
||||
one_user_sedp = sum(diff) / self.Ks[k]
|
||||
sedp[k].append(one_user_sedp)
|
||||
|
||||
sedp = np.array(sedp).mean(1)
|
||||
|
||||
return result['recall'].tolist(), result['ndcg'].tolist(), \
|
||||
[sedp[1], sedp[2]], result['nov'].tolist()
|
@ -0,0 +1,67 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils for training BGCF"""
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
import shutil
|
||||
import pickle as pkl
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_pickle(path, name):
|
||||
"""Load pickle"""
|
||||
with open(path + name, 'rb') as f:
|
||||
return pkl.load(f, encoding='latin1')
|
||||
|
||||
|
||||
class BGCFLogger:
|
||||
"""log the output metrics"""
|
||||
|
||||
def __init__(self, logname, now, foldername, copy):
|
||||
self.terminal = sys.stdout
|
||||
self.file = None
|
||||
|
||||
path = os.path.join(foldername, logname, now)
|
||||
os.makedirs(path)
|
||||
|
||||
if copy:
|
||||
filenames = glob.glob('*.py')
|
||||
for filename in filenames:
|
||||
shutil.copy(filename, path)
|
||||
|
||||
def open(self, file, mode=None):
|
||||
if mode is None:
|
||||
mode = 'w'
|
||||
self.file = open(file, mode)
|
||||
|
||||
def write(self, message, is_terminal=True, is_file=True):
|
||||
"""Write log"""
|
||||
if '\r' in message:
|
||||
is_file = False
|
||||
|
||||
if is_terminal:
|
||||
self.terminal.write(message)
|
||||
self.terminal.flush()
|
||||
|
||||
if is_file:
|
||||
self.file.write(message)
|
||||
self.file.flush()
|
||||
|
||||
|
||||
def convert_item_id(item_list, num_user):
|
||||
"""Convert the graph node id into item id"""
|
||||
return np.array(item_list) - num_user
|
@ -0,0 +1,173 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
BGCF training script.
|
||||
"""
|
||||
import os
|
||||
import datetime
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||
|
||||
from src.bgcf import BGCF
|
||||
from src.metrics import BGCFEvaluate
|
||||
from src.config import parser_args
|
||||
from src.utils import BGCFLogger, convert_item_id
|
||||
from src.callback import ForwardBGCF, TrainBGCF, TestBGCF
|
||||
from src.dataset import load_graph, create_dataset, TestGraphDataset
|
||||
|
||||
|
||||
def train_and_eval():
|
||||
"""Train and eval"""
|
||||
num_user = train_graph.graph_info()["node_num"][0]
|
||||
num_item = train_graph.graph_info()["node_num"][1]
|
||||
num_pairs = train_graph.graph_info()['edge_num'][0]
|
||||
|
||||
bgcfnet = BGCF([parser.input_dim, num_user, num_item],
|
||||
parser.embedded_dimension,
|
||||
parser.activation,
|
||||
parser.neighbor_dropout,
|
||||
num_user,
|
||||
num_item,
|
||||
parser.input_dim)
|
||||
|
||||
train_net = TrainBGCF(bgcfnet, parser.num_neg, parser.l2, parser.learning_rate,
|
||||
parser.epsilon, parser.dist_reg)
|
||||
train_net.set_train(True)
|
||||
|
||||
eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)
|
||||
|
||||
itr = train_ds.create_dict_iterator(parser.num_epoch)
|
||||
num_iter = int(num_pairs / parser.batch_pairs)
|
||||
|
||||
for _epoch in range(1, parser.num_epoch + 1):
|
||||
|
||||
iter_num = 1
|
||||
|
||||
for data in itr:
|
||||
|
||||
u_id = Tensor(data["users"], mstype.int32)
|
||||
pos_item_id = Tensor(convert_item_id(data["items"], num_user), mstype.int32)
|
||||
neg_item_id = Tensor(convert_item_id(data["neg_item_id"], num_user), mstype.int32)
|
||||
pos_users = Tensor(data["pos_users"], mstype.int32)
|
||||
pos_items = Tensor(convert_item_id(data["pos_items"], num_user), mstype.int32)
|
||||
|
||||
u_group_nodes = Tensor(data["u_group_nodes"], mstype.int32)
|
||||
u_neighs = Tensor(convert_item_id(data["u_neighs"], num_user), mstype.int32)
|
||||
u_gnew_neighs = Tensor(convert_item_id(data["u_gnew_neighs"], num_user), mstype.int32)
|
||||
|
||||
i_group_nodes = Tensor(convert_item_id(data["i_group_nodes"], num_user), mstype.int32)
|
||||
i_neighs = Tensor(data["i_neighs"], mstype.int32)
|
||||
i_gnew_neighs = Tensor(data["i_gnew_neighs"], mstype.int32)
|
||||
|
||||
neg_group_nodes = Tensor(convert_item_id(data["neg_group_nodes"], num_user), mstype.int32)
|
||||
neg_neighs = Tensor(data["neg_neighs"], mstype.int32)
|
||||
neg_gnew_neighs = Tensor(data["neg_gnew_neighs"], mstype.int32)
|
||||
|
||||
train_loss = train_net(u_id,
|
||||
pos_item_id,
|
||||
neg_item_id,
|
||||
pos_users,
|
||||
pos_items,
|
||||
u_group_nodes,
|
||||
u_neighs,
|
||||
u_gnew_neighs,
|
||||
i_group_nodes,
|
||||
i_neighs,
|
||||
i_gnew_neighs,
|
||||
neg_group_nodes,
|
||||
neg_neighs,
|
||||
neg_gnew_neighs)
|
||||
|
||||
if iter_num == num_iter:
|
||||
print('Epoch', '%03d' % _epoch, 'iter', '%02d' % iter_num,
|
||||
'loss',
|
||||
'{}'.format(train_loss))
|
||||
iter_num += 1
|
||||
|
||||
if _epoch % parser.eval_interval == 0:
|
||||
if os.path.exists("ckpts/bgcf.ckpt"):
|
||||
os.remove("ckpts/bgcf.ckpt")
|
||||
save_checkpoint(bgcfnet, "ckpts/bgcf.ckpt")
|
||||
|
||||
bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
|
||||
parser.embedded_dimension,
|
||||
parser.activation,
|
||||
[0.0, 0.0, 0.0],
|
||||
num_user,
|
||||
num_item,
|
||||
parser.input_dim)
|
||||
|
||||
load_checkpoint("ckpts/bgcf.ckpt", net=bgcfnet_test)
|
||||
|
||||
forward_net = ForwardBGCF(bgcfnet_test)
|
||||
user_reps, item_reps = TestBGCF(forward_net, num_user, num_item, parser.input_dim, test_graph_dataset)
|
||||
|
||||
test_recall_bgcf, test_ndcg_bgcf, \
|
||||
test_sedp, test_nov = eval_class.eval_with_rep(user_reps, item_reps, parser)
|
||||
|
||||
if parser.log_name:
|
||||
log.write(
|
||||
'epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, '
|
||||
'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch,
|
||||
test_recall_bgcf[1],
|
||||
test_recall_bgcf[2],
|
||||
test_ndcg_bgcf[1],
|
||||
test_ndcg_bgcf[2],
|
||||
test_sedp[0],
|
||||
test_sedp[1],
|
||||
test_nov[1],
|
||||
test_nov[2]))
|
||||
else:
|
||||
print('epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, '
|
||||
'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch,
|
||||
test_recall_bgcf[1],
|
||||
test_recall_bgcf[2],
|
||||
test_ndcg_bgcf[1],
|
||||
test_ndcg_bgcf[2],
|
||||
test_sedp[0],
|
||||
test_sedp[1],
|
||||
test_nov[1],
|
||||
test_nov[2]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False)
|
||||
|
||||
parser = parser_args()
|
||||
|
||||
train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
|
||||
train_ds = create_dataset(train_graph, sampled_graph_list, batch_size=parser.batch_pairs)
|
||||
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,
|
||||
num_bgcn_neigh=parser.gnew_neighs,
|
||||
num_neg=parser.num_neg)
|
||||
|
||||
if parser.log_name:
|
||||
now = datetime.datetime.now().strftime("%b_%d_%H_%M_%S")
|
||||
name = "bgcf" + '-' + parser.log_name + '-' + parser.dataset
|
||||
log_save_path = './log-files/' + name + '/' + now
|
||||
log = BGCFLogger(logname=name, now=now, foldername='log-files', copy=False)
|
||||
log.open(log_save_path + '/log.train.txt', mode='a')
|
||||
for arg in vars(parser):
|
||||
log.write(arg + '=' + str(getattr(parser, arg)) + '\n')
|
||||
else:
|
||||
for arg in vars(parser):
|
||||
print(arg + '=' + str(getattr(parser, arg)))
|
||||
|
||||
train_and_eval()
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,75 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
User-defined API for MindRecord GNN writer.
|
||||
"""
|
||||
import os
|
||||
import csv
|
||||
|
||||
args = os.environ['graph_api_args'].split(':')
|
||||
USER_FILE = args[0]
|
||||
ITEM_FILE = args[1]
|
||||
RATING_FILE = args[2]
|
||||
|
||||
node_profile = (0, [], [])
|
||||
edge_profile = (0, [], [])
|
||||
|
||||
|
||||
def yield_nodes(task_id=0):
|
||||
"""
|
||||
Generate node data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
with open(USER_FILE) as user_file:
|
||||
user_reader = csv.reader(user_file, delimiter=',')
|
||||
line_count = 0
|
||||
for row in user_reader:
|
||||
node = {'id': int(row[1]), 'type': 0}
|
||||
yield node
|
||||
line_count += 1
|
||||
print('Processed {} lines for users.'.format(line_count))
|
||||
|
||||
with open(ITEM_FILE) as item_file:
|
||||
item_reader = csv.reader(item_file, delimiter=',')
|
||||
line_count = 0
|
||||
for row in item_reader:
|
||||
node = {'id': int(row[1]), 'type': 1,}
|
||||
yield node
|
||||
line_count += 1
|
||||
print('Processed {} lines for items.'.format(line_count))
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
"""
|
||||
Generate edge data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
with open(RATING_FILE) as rating_file:
|
||||
rating_reader = csv.reader(rating_file, delimiter=',')
|
||||
line_count = 0
|
||||
for row in rating_reader:
|
||||
if line_count == 0:
|
||||
line_count += 1
|
||||
continue
|
||||
edge = {'id': line_count - 1, 'src_id': int(row[0]), 'dst_id': int(row[1]), 'type': int(row[2])}
|
||||
yield edge
|
||||
line_count += 1
|
||||
print('Processed {} lines for edges.'.format(line_count))
|
Loading…
Reference in new issue