!8756 [ModelZoo]Change gnn seed

From: @zhan_ke
Reviewed-by: @oacjiewen,@c_34
Signed-off-by: @c_34
pull/8756/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit adfd70e9d8

@ -19,7 +19,6 @@ import os
import numpy as np
import mindspore.context as context
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from mindspore.common import set_seed
from mindspore import Tensor
from src.config import GatConfig
@ -27,7 +26,6 @@ from src.dataset import load_and_process
from src.gat import GAT
from src.utils import LossAccuracyWrapper, TrainGAT
set_seed(0)
def train():
"""Train GAT model."""

@ -27,7 +27,6 @@ from matplotlib import animation
from sklearn import manifold
from mindspore import context
from mindspore import Tensor
from mindspore.common import set_seed
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from src.gcn import GCN
@ -51,7 +50,6 @@ def train():
"""Train model."""
parser = argparse.ArgumentParser(description='GCN')
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
parser.add_argument('--seed', type=int, default=0, help='Random seed')
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
@ -60,7 +58,6 @@ def train():
if not os.path.exists("ckpts"):
os.mkdir("ckpts")
set_seed(args_opt.seed)
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False)
config = ConfigGCN()

Loading…
Cancel
Save