|
|
|
@ -27,7 +27,6 @@ from matplotlib import animation
|
|
|
|
|
from sklearn import manifold
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
|
from mindspore.common import set_seed
|
|
|
|
|
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
|
|
|
|
|
|
|
|
|
from src.gcn import GCN
|
|
|
|
@ -51,7 +50,6 @@ def train():
|
|
|
|
|
"""Train model."""
|
|
|
|
|
parser = argparse.ArgumentParser(description='GCN')
|
|
|
|
|
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
|
|
|
|
|
parser.add_argument('--seed', type=int, default=0, help='Random seed')
|
|
|
|
|
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
|
|
|
|
|
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
|
|
|
|
|
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
|
|
|
|
@ -60,7 +58,6 @@ def train():
|
|
|
|
|
if not os.path.exists("ckpts"):
|
|
|
|
|
os.mkdir("ckpts")
|
|
|
|
|
|
|
|
|
|
set_seed(args_opt.seed)
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE,
|
|
|
|
|
device_target="Ascend", save_graphs=False)
|
|
|
|
|
config = ConfigGCN()
|
|
|
|
|