|
|
|
@ -15,7 +15,6 @@
|
|
|
|
|
"""
|
|
|
|
|
BGCF training script.
|
|
|
|
|
"""
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
@ -102,12 +101,12 @@ def train():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
parser = parser_args()
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE,
|
|
|
|
|
device_target="Ascend",
|
|
|
|
|
save_graphs=False)
|
|
|
|
|
save_graphs=False,
|
|
|
|
|
device_id=parser.device)
|
|
|
|
|
|
|
|
|
|
parser = parser_args()
|
|
|
|
|
os.environ['DEVICE_ID'] = parser.device
|
|
|
|
|
train_graph, _, sampled_graph_list = load_graph(parser.datapath)
|
|
|
|
|
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,
|
|
|
|
|
num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg)
|
|
|
|
|