!11257 [ModelZoo]fix bgcf device id bug

From: @zhan_ke
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @c_34
pull/11257/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 61d6d85347

@ -15,7 +15,6 @@
"""
BGCF training script.
"""
import os
import time
from mindspore import Tensor
@ -102,12 +101,12 @@ def train():
if __name__ == "__main__":
parser = parser_args()
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False)
save_graphs=False,
device_id=parser.device)
parser = parser_args()
os.environ['DEVICE_ID'] = parser.device
train_graph, _, sampled_graph_list = load_graph(parser.datapath)
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,
num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg)

Loading…
Cancel
Save