fix device id bug

pull/11257/head
zhanke 4 years ago
parent 72dc7b208c
commit 432984d86f

@ -15,7 +15,6 @@
"""
BGCF training script.
"""
import os
import time
from mindspore import Tensor
@ -102,12 +101,12 @@ def train():
if __name__ == "__main__":
parser = parser_args()
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False)
save_graphs=False,
device_id=parser.device)
parser = parser_args()
os.environ['DEVICE_ID'] = parser.device
train_graph, _, sampled_graph_list = load_graph(parser.datapath)
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,
num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg)

Loading…
Cancel
Save