diff --git a/model_zoo/gcn/README.md b/model_zoo/gcn/README.md new file mode 100644 index 0000000000..a2e97c7c3a --- /dev/null +++ b/model_zoo/gcn/README.md @@ -0,0 +1,113 @@ +# GCN Example + +## Description + +This is an example of training GCN on the Cora and Citeseer datasets in MindSpore. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the Cora or Citeseer dataset provided by the kimiyoung/planetoid repository on GitHub. + +> Place the dataset in any path you want; the folder should include the following files (we use the Cora dataset as an example): + +``` +. +└─data + ├─ind.cora.allx + ├─ind.cora.ally + ├─ind.cora.graph + ├─ind.cora.test.index + ├─ind.cora.tx + ├─ind.cora.ty + ├─ind.cora.x + └─ind.cora.y +``` + +> Generate the dataset in mindrecord format for cora or citeseer. +>> Usage +```buildoutcfg +cd ./scripts +# SRC_PATH is the path of the dataset you downloaded, DATASET_NAME is cora or citeseer +sh run_process_data.sh [SRC_PATH] [DATASET_NAME] +``` + +>> Launch +``` +# Generate the dataset in mindrecord format for cora +sh run_process_data.sh ./data cora +# Generate the dataset in mindrecord format for citeseer +sh run_process_data.sh ./data citeseer +``` + +## Structure + +```shell +. +└─gcn + ├─README.md + ├─scripts + | ├─run_process_data.sh # Generate dataset in mindrecord format + | └─run_train.sh # Launch training + | + ├─src + | ├─config.py # Parameter configuration + | ├─dataset.py # Data preprocessing + | ├─gcn.py # GCN backbone + | └─metrics.py # Loss and accuracy + | + └─train.py # Train net +``` + +## Parameter configuration + +Parameters for training can be set in config.py. 
+ +``` +"learning_rate": 0.01, # Learning rate +"epochs": 200, # Number of training epochs +"hidden1": 16, # Hidden size for the first graph convolution layer +"dropout": 0.5, # Dropout ratio for the first graph convolution layer +"weight_decay": 5e-4, # Weight decay for the parameter of the first graph convolution layer +"early_stopping": 10, # Tolerance for early stopping +``` + +## Running the example + +### Train + +#### Usage + +``` +# run training with the cora or citeseer dataset, DATASET_NAME is cora or citeseer +sh run_train.sh [DATASET_NAME] +``` + +#### Launch + +```bash +sh run_train.sh cora +``` + +#### Result + +Training results are stored under the scripts path, in a folder whose name begins with "train". The log contains output like the following. + + +``` +Epoch: 0001 train_loss= 1.95373 train_acc= 0.09286 val_loss= 1.95075 val_acc= 0.20200 time= 7.25737 +Epoch: 0002 train_loss= 1.94812 train_acc= 0.32857 val_loss= 1.94717 val_acc= 0.34000 time= 0.00438 +Epoch: 0003 train_loss= 1.94249 train_acc= 0.47857 val_loss= 1.94337 val_acc= 0.43000 time= 0.00428 +Epoch: 0004 train_loss= 1.93550 train_acc= 0.55000 val_loss= 1.93957 val_acc= 0.46400 time= 0.00421 +Epoch: 0005 train_loss= 1.92617 train_acc= 0.67143 val_loss= 1.93558 val_acc= 0.45400 time= 0.00430 +... +Epoch: 0196 train_loss= 0.60326 train_acc= 0.97857 val_loss= 1.05155 val_acc= 0.78200 time= 0.00418 +Epoch: 0197 train_loss= 0.60377 train_acc= 0.97143 val_loss= 1.04940 val_acc= 0.78000 time= 0.00418 +Epoch: 0198 train_loss= 0.60680 train_acc= 0.95000 val_loss= 1.04847 val_acc= 0.78000 time= 0.00414 +Epoch: 0199 train_loss= 0.61920 train_acc= 0.96429 val_loss= 1.04797 val_acc= 0.78400 time= 0.00413 +Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.78600 time= 0.00415 +Optimization Finished! +Test set results: loss= 1.00983 accuracy= 0.81300 time= 0.39083 +... 
+``` diff --git a/model_zoo/gcn/scripts/run_process_data.sh b/model_zoo/gcn/scripts/run_process_data.sh new file mode 100755 index 0000000000..4501f3c67f --- /dev/null +++ b/model_zoo/gcn/scripts/run_process_data.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_process_data.sh [SRC_PATH] [DATASET_NAME]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m "$PWD/$1")" + fi +} +SRC_PATH=$(get_real_path "$1") +echo $SRC_PATH + +DATASET_NAME=$2 +echo $DATASET_NAME + +if [ ! 
-d data_mr ]; then + mkdir data_mr +else + echo data_mr exist +fi +MINDRECORD_PATH=`pwd`/data_mr + +rm -f $MINDRECORD_PATH/* + +cd ../../../example/graph_to_mindrecord || exit + +python writer.py --mindrecord_script $DATASET_NAME \ +--mindrecord_file "$MINDRECORD_PATH/$DATASET_NAME" \ +--mindrecord_partitions 1 \ +--mindrecord_header_size_by_bit 18 \ +--mindrecord_page_size_by_bit 20 \ +--graph_api_args "$SRC_PATH" + +cd - || exit diff --git a/model_zoo/gcn/scripts/run_train.sh b/model_zoo/gcn/scripts/run_train.sh new file mode 100755 index 0000000000..46dee49b0d --- /dev/null +++ b/model_zoo/gcn/scripts/run_train.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [ $# != 1 ] +then + echo "Usage: sh run_train.sh [DATASET_NAME]" +exit 1 +fi + +DATASET_NAME=$1 +echo $DATASET_NAME + +ulimit -u unlimited +export DEVICE_NUM=1 +export RANK_SIZE=$DEVICE_NUM +export DEVICE_ID=0 +export RANK_ID=0 + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp *.sh ./train +cp -r ../src ./train +cd ./train || exit +env > env.log +echo "start training for device $DEVICE_ID" + + +if [ $DATASET_NAME == cora ] +then + python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=140 &> log & +fi + +if [ $DATASET_NAME == citeseer ] +then + python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=120 &> log & +fi +cd .. + diff --git a/model_zoo/gcn/src/config.py b/model_zoo/gcn/src/config.py new file mode 100644 index 0000000000..83974d706c --- /dev/null +++ b/model_zoo/gcn/src/config.py @@ -0,0 +1,26 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +network config setting, will be used in train.py +""" + + +class ConfigGCN(): + learning_rate = 0.01 + epochs = 200 + hidden1 = 16 + dropout = 0.5 + weight_decay = 5e-4 + early_stopping = 10 diff --git a/model_zoo/gcn/src/dataset.py b/model_zoo/gcn/src/dataset.py new file mode 100644 index 0000000000..39843b5af7 --- /dev/null +++ b/model_zoo/gcn/src/dataset.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +create adjacency matrix, node features, labels, and mask for training. +""" +import numpy as np +import scipy.sparse as sp +import mindspore.dataset as ds + + +def normalize_adj(adj): + """Symmetrically normalize adjacency matrix.""" + rowsum = np.array(adj.sum(1)) + d_inv_sqrt = np.power(rowsum, -0.5).flatten() + d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 
+ d_mat_inv_sqrt = sp.diags(d_inv_sqrt) + return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() + + +def get_adj_features_labels(data_dir): + """Get adjacency matrix, node features and labels from dataset.""" + g = ds.GraphData(data_dir) + nodes = g.get_all_nodes(0) + nodes_list = nodes.tolist() + row_tensor = g.get_node_feature(nodes_list, [1, 2]) + features = row_tensor[0] + labels = row_tensor[1] + + nodes_num = labels.shape[0] + class_num = labels.max() + 1 + labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32) + + neighbor = g.get_all_neighbors(nodes_list, 0) + node_map = {node_id: index for index, node_id in enumerate(nodes_list)} + adj = np.zeros([nodes_num, nodes_num], dtype=np.float32) + for index, value in np.ndenumerate(neighbor): + # The first column of neighbor is the node_id; the remaining columns are the neighbors of that node, + # so only entries with index[1] > 0 describe edges. + # If a node does not have that many neighbors, -1 is padded, so entries with value < 0 are skipped. + if value >= 0 and index[1] > 0: + adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1 + adj = sp.coo_matrix(adj) + adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num) + nor_adj = normalize_adj(adj) + nor_adj = np.array(nor_adj.todense()) + return nor_adj, features, labels_onehot + + +def get_mask(total, begin, end): + """Generate mask.""" + mask = np.zeros([total]).astype(np.float32) + mask[begin:end] = 1 + return mask diff --git a/model_zoo/gcn/src/gcn.py b/model_zoo/gcn/src/gcn.py new file mode 100644 index 0000000000..8bad127ff5 --- /dev/null +++ b/model_zoo/gcn/src/gcn.py @@ -0,0 +1,220 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""GCN.""" +import numpy as np +from mindspore import nn +from mindspore.common.parameter import ParameterTuple +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore import Tensor +from mindspore.nn.layer.activation import get_activation +from src.metrics import Loss, Accuracy + + +def glorot(shape): + init_range = np.sqrt(6.0/(shape[0]+shape[1])) + initial = np.random.uniform(-init_range, init_range, shape).astype(np.float32) + return Tensor(initial) + + +class GraphConvolution(nn.Cell): + """ + GCN graph convolution layer. + + Args: + feature_in_dim (int): The input feature dimension. + feature_out_dim (int): The output feature dimension. + dropout_ratio (float): Dropout ratio for the dropout layer. Default: None. + activation (str): Activation function applied to the output of the layer, eg. 'relu'. Default: None. + + Inputs: + - **adj** (Tensor) - Tensor of shape :math:`(N, N)`. + - **input_feature** (Tensor) - Tensor of shape :math:`(N, C)`. + + Outputs: + Tensor, output tensor. 
+ """ + + def __init__(self, + feature_in_dim, + feature_out_dim, + dropout_ratio=None, + activation=None): + super(GraphConvolution, self).__init__() + self.in_dim = feature_in_dim + self.out_dim = feature_out_dim + self.weight_init = glorot([self.out_dim, self.in_dim]) + self.fc = nn.Dense(self.in_dim, + self.out_dim, + weight_init=self.weight_init, + has_bias=False) + self.dropout_ratio = dropout_ratio + if self.dropout_ratio is not None: + self.dropout = nn.Dropout(keep_prob=1-self.dropout_ratio) + self.dropout_flag = self.dropout_ratio is not None + self.activation = get_activation(activation) + self.activation_flag = self.activation is not None + self.matmul = P.MatMul() + + def construct(self, adj, input_feature): + dropout = input_feature + if self.dropout_flag: + dropout = self.dropout(dropout) + + fc = self.fc(dropout) + output_feature = self.matmul(adj, fc) + + if self.activation_flag: + output_feature = self.activation(output_feature) + return output_feature + + +class GCN(nn.Cell): + """ + GCN architecture. + + Args: + config (ConfigGCN): Configuration for GCN. + adj (numpy.ndarray): Adjacency matrix of the graph, fed to both graph convolution layers. + feature (numpy.ndarray): Node feature matrix; feature.shape[1] is the input dimension of the first layer. + output_dim (int): The number of output channels, equal to the number of classes. + """ + + def __init__(self, config, adj, feature, output_dim): + super(GCN, self).__init__() + self.adj = Tensor(adj) + self.feature = Tensor(feature) + input_dim = feature.shape[1] + self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout) + self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None) + + def construct(self): + output0 = self.layer0(self.adj, self.feature) + output1 = self.layer1(self.adj, output0) + return output1 + + +class LossAccuracyWrapper(nn.Cell): + """ + Wraps the GCN model with loss and accuracy cell. + + Args: + network (Cell): GCN network. + label (numpy.ndarray): Dataset labels. 
+ mask (numpy.ndarray): Mask for training, evaluation or test. + weight_decay (float): Weight decay parameter for weight of the first convolution layer. + """ + + def __init__(self, network, label, mask, weight_decay): + super(LossAccuracyWrapper, self).__init__() + self.network = network + self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) + self.accuracy = Accuracy(label, mask) + + def construct(self): + preds = self.network() + loss = self.loss(preds) + accuracy = self.accuracy(preds) + return loss, accuracy + + +class LossWrapper(nn.Cell): + """ + Wraps the GCN model with loss. + + Args: + network (Cell): GCN network. + label (numpy.ndarray): Dataset labels. + mask (numpy.ndarray): Mask for training. + weight_decay (float): Weight decay parameter for weight of the first convolution layer. + """ + + def __init__(self, network, label, mask, weight_decay): + super(LossWrapper, self).__init__() + self.network = network + self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) + + def construct(self): + preds = self.network() + loss = self.loss(preds) + return loss + + +class TrainOneStepCell(nn.Cell): + r""" + Network training package class. + + Wraps the network with an optimizer. The resulting Cell can be trained without inputs. + Backward graph will be created in the construct function to do parameter updating. Different + parallel modes are available to run the training. + + Args: + network (Cell): The training network. + optimizer (Cell): Optimizer for updating the weights. + sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. + + Outputs: + Tensor, a scalar Tensor with shape :math:`()`. 
+ + Examples: + >>> net = Net() + >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() + >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> loss_net = nn.WithLossCell(net, loss_fn) + >>> train_net = nn.TrainOneStepCell(loss_net, optim) + """ + + def __init__(self, network, optimizer, sens=1.0): + super(TrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.network.add_flags(defer_inline=True) + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + + def construct(self): + weights = self.weights + loss = self.network() + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(sens) + return F.depend(loss, self.optimizer(grads)) + + +class TrainNetWrapper(nn.Cell): + """ + Wraps the GCN model with optimizer. + + Args: + network (Cell): GCN network. + label (numpy.ndarray): Dataset labels. + mask (numpy.ndarray): Mask for training, evaluation or test. + config (ConfigGCN): Configuration for GCN. 
+ """ + + def __init__(self, network, label, mask, config): + super(TrainNetWrapper, self).__init__(auto_prefix=True) + self.network = network + loss_net = LossWrapper(network, label, mask, config.weight_decay) + optimizer = nn.Adam(loss_net.trainable_params(), + learning_rate=config.learning_rate) + self.loss_train_net = TrainOneStepCell(loss_net, optimizer) + self.accuracy = Accuracy(label, mask) + + def construct(self): + loss = self.loss_train_net() + accuracy = self.accuracy(self.network()) + return loss, accuracy diff --git a/model_zoo/gcn/src/metrics.py b/model_zoo/gcn/src/metrics.py new file mode 100644 index 0000000000..5930956776 --- /dev/null +++ b/model_zoo/gcn/src/metrics.py @@ -0,0 +1,70 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Loss and accuracy.""" +from mindspore import nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + + +class Loss(nn.Cell): + """Softmax cross-entropy loss with masking.""" + def __init__(self, label, mask, weight_decay, param): + super(Loss, self).__init__() + self.label = Tensor(label) + self.mask = Tensor(mask) + self.loss = P.SoftmaxCrossEntropyWithLogits() + self.one = Tensor(1.0, mstype.float32) + self.zero = Tensor(0.0, mstype.float32) + self.mean = P.ReduceMean() + self.cast = P.Cast() + self.l2_loss = P.L2Loss() + self.reduce_sum = P.ReduceSum() + self.weight_decay = weight_decay + self.param = param + + def construct(self, preds): + param = self.l2_loss(self.param) + loss = self.weight_decay * param + preds = self.cast(preds, mstype.float32) + loss = loss + self.loss(preds, self.label)[0] + mask = self.cast(self.mask, mstype.float32) + mask_reduce = self.mean(mask) + mask = mask / mask_reduce + loss = loss * mask + loss = self.mean(loss) + return loss + + +class Accuracy(nn.Cell): + """Accuracy with masking.""" + def __init__(self, label, mask): + super(Accuracy, self).__init__() + self.label = Tensor(label) + self.mask = Tensor(mask) + self.equal = P.Equal() + self.argmax = P.Argmax() + self.cast = P.Cast() + self.mean = P.ReduceMean() + + def construct(self, preds): + preds = self.cast(preds, mstype.float32) + correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label)) + accuracy_all = self.cast(correct_prediction, mstype.float32) + mask = self.cast(self.mask, mstype.float32) + mask_reduce = self.mean(mask) + mask = mask / mask_reduce + accuracy_all *= mask + return self.mean(accuracy_all) diff --git a/model_zoo/gcn/train.py b/model_zoo/gcn/train.py new file mode 100644 index 0000000000..c3502ab3da --- /dev/null +++ b/model_zoo/gcn/train.py @@ -0,0 +1,93 @@ +# Copyright 2020 Huawei 
Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +GCN training script. +""" + +import time +import argparse + +import numpy as np +from mindspore import context + +from src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper +from src.config import ConfigGCN +from src.dataset import get_adj_features_labels, get_mask + + +def train(): + """Train model.""" + parser = argparse.ArgumentParser(description='GCN') + parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory') + parser.add_argument('--seed', type=int, default=123, help='Random seed') + parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training') + parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation') + parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test') + args_opt = parser.parse_args() + + np.random.seed(args_opt.seed) + context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", save_graphs=False) + config = ConfigGCN() + adj, feature, label = get_adj_features_labels(args_opt.data_dir) + + nodes_num = label.shape[0] + train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num) + eval_mask = get_mask(nodes_num, args_opt.train_nodes_num, args_opt.train_nodes_num + args_opt.eval_nodes_num) + test_mask = get_mask(nodes_num, 
nodes_num - args_opt.test_nodes_num, nodes_num) + + class_num = label.shape[1] + gcn_net = GCN(config, adj, feature, class_num) + gcn_net.add_flags_recursive(fp16=True) + + eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, config.weight_decay) + test_net = LossAccuracyWrapper(gcn_net, label, test_mask, config.weight_decay) + train_net = TrainNetWrapper(gcn_net, label, train_mask, config) + + loss_list = [] + for epoch in range(config.epochs): + t = time.time() + + train_net.set_train() + train_result = train_net() + train_loss = train_result[0].asnumpy() + train_accuracy = train_result[1].asnumpy() + + eval_net.set_train(False) + eval_result = eval_net() + eval_loss = eval_result[0].asnumpy() + eval_accuracy = eval_result[1].asnumpy() + + loss_list.append(eval_loss) + print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss), + "train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss), + "val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t)) + + if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]): + print("Early stopping...") + break + + t_test = time.time() + test_net.set_train(False) + test_result = test_net() + test_loss = test_result[0].asnumpy() + test_accuracy = test_result[1].asnumpy() + print("Test set results:", "loss=", "{:.5f}".format(test_loss), + "accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test)) + + +if __name__ == '__main__': + train() diff --git a/tests/st/gnn/gcn/src/config.py b/tests/st/gnn/gcn/src/config.py index 053e48cad9..4276e54262 100644 --- a/tests/st/gnn/gcn/src/config.py +++ b/tests/st/gnn/gcn/src/config.py @@ -13,6 +13,7 @@ # limitations under the License. 
# ============================================================================ + class ConfigGCN(): learning_rate = 0.01 epochs = 200 diff --git a/tests/st/gnn/gcn/test_gcn.py b/tests/st/gnn/gcn/test_gcn.py index 21304a7b8b..5f01c455e7 100644 --- a/tests/st/gnn/gcn/test_gcn.py +++ b/tests/st/gnn/gcn/test_gcn.py @@ -58,10 +58,12 @@ def test_gcn(): for epoch in range(config.epochs): t = time.time() + train_net.set_train() train_result = train_net() train_loss = train_result[0].asnumpy() train_accuracy = train_result[1].asnumpy() + eval_net.set_train(False) eval_result = eval_net() eval_loss = eval_result[0].asnumpy() eval_accuracy = eval_result[1].asnumpy() @@ -75,6 +77,7 @@ def test_gcn(): print("Early stopping...") break + test_net.set_train(False) test_result = test_net() test_loss = test_result[0].asnumpy() test_accuracy = test_result[1].asnumpy()