!8756 [ModelZoo]Change gnn seed

From: @zhan_ke
Reviewed-by: @oacjiewen,@c_34
Signed-off-by: @c_34
pull/8756/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit adfd70e9d8

@ -19,7 +19,6 @@ import os
import numpy as np import numpy as np
import mindspore.context as context import mindspore.context as context
from mindspore.train.serialization import save_checkpoint, load_checkpoint from mindspore.train.serialization import save_checkpoint, load_checkpoint
from mindspore.common import set_seed
from mindspore import Tensor from mindspore import Tensor
from src.config import GatConfig from src.config import GatConfig
@ -27,7 +26,6 @@ from src.dataset import load_and_process
from src.gat import GAT from src.gat import GAT
from src.utils import LossAccuracyWrapper, TrainGAT from src.utils import LossAccuracyWrapper, TrainGAT
set_seed(0)
def train(): def train():
"""Train GAT model.""" """Train GAT model."""

@ -27,7 +27,6 @@ from matplotlib import animation
from sklearn import manifold from sklearn import manifold
from mindspore import context from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore.common import set_seed
from mindspore.train.serialization import save_checkpoint, load_checkpoint from mindspore.train.serialization import save_checkpoint, load_checkpoint
from src.gcn import GCN from src.gcn import GCN
@ -51,7 +50,6 @@ def train():
"""Train model.""" """Train model."""
parser = argparse.ArgumentParser(description='GCN') parser = argparse.ArgumentParser(description='GCN')
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory') parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
parser.add_argument('--seed', type=int, default=0, help='Random seed')
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training') parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation') parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test') parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
@ -60,7 +58,6 @@ def train():
if not os.path.exists("ckpts"): if not os.path.exists("ckpts"):
os.mkdir("ckpts") os.mkdir("ckpts")
set_seed(args_opt.seed)
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False) device_target="Ascend", save_graphs=False)
config = ConfigGCN() config = ConfigGCN()

Loading…
Cancel
Save