From e801d489064aa939a396150197e01d75b72f342d Mon Sep 17 00:00:00 2001
From: chentingting
Date: Tue, 12 May 2020 18:03:46 +0800
Subject: [PATCH] add gcn training scripts

---
 tests/st/gnn/gcn/__init__.py     |   0
 tests/st/gnn/gcn/src/__init__.py |   0
 tests/st/gnn/gcn/src/config.py   |  22 +++++
 tests/st/gnn/gcn/src/dataset.py  |  60 ++++++++++++
 tests/st/gnn/gcn/src/gcn.py      | 163 +++++++++++++++++++++++++++++++
 tests/st/gnn/gcn/src/metrics.py  |  68 +++++++++++++
 tests/st/gnn/gcn/test_gcn.py     |  83 ++++++++++++++++
 7 files changed, 396 insertions(+)
 create mode 100644 tests/st/gnn/gcn/__init__.py
 create mode 100644 tests/st/gnn/gcn/src/__init__.py
 create mode 100644 tests/st/gnn/gcn/src/config.py
 create mode 100644 tests/st/gnn/gcn/src/dataset.py
 create mode 100644 tests/st/gnn/gcn/src/gcn.py
 create mode 100644 tests/st/gnn/gcn/src/metrics.py
 create mode 100644 tests/st/gnn/gcn/test_gcn.py

diff --git a/tests/st/gnn/gcn/__init__.py b/tests/st/gnn/gcn/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/st/gnn/gcn/src/__init__.py b/tests/st/gnn/gcn/src/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/st/gnn/gcn/src/config.py b/tests/st/gnn/gcn/src/config.py
new file mode 100644
index 0000000000..053e48cad9
--- /dev/null
+++ b/tests/st/gnn/gcn/src/config.py
@@ -0,0 +1,22 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+class ConfigGCN():
+    learning_rate = 0.01
+    epochs = 200
+    hidden1 = 16
+    dropout = 0.0
+    weight_decay = 5e-4
+    early_stopping = 10
diff --git a/tests/st/gnn/gcn/src/dataset.py b/tests/st/gnn/gcn/src/dataset.py
new file mode 100644
index 0000000000..cdf1481120
--- /dev/null
+++ b/tests/st/gnn/gcn/src/dataset.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import scipy.sparse as sp
+import mindspore.dataset as ds
+
+
+def normalize_adj(adj):
+    rowsum = np.array(adj.sum(1))
+    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
+    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
+    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
+    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
+
+
+def get_adj_features_labels(data_dir):
+    g = ds.GraphData(data_dir)
+    nodes = g.get_all_nodes(0)
+    nodes_list = nodes.tolist()
+    row_tensor = g.get_node_feature(nodes_list, [1, 2])
+    features = row_tensor[0]
+    labels = row_tensor[1]
+
+    nodes_num = labels.shape[0]
+    class_num = labels.max() + 1
+    labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32)
+
+    neighbor = g.get_all_neighbors(nodes_list, 0)
+    node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
+    adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
+    for index, value in np.ndenumerate(neighbor):
+        # The first column of neighbor is the node_id, and the second through last
+        # columns are its neighbors, so we only care about entries with index[1] > 0.
+        # Rows of nodes with fewer neighbors are padded with -1, so entries with value < 0 are skipped.
+        if value >= 0 and index[1] > 0:
+            adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
+    adj = sp.coo_matrix(adj)
+    adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num)
+    nor_adj = normalize_adj(adj)
+    nor_adj = np.array(nor_adj.todense())
+    return nor_adj, features, labels_onehot
+
+
+def get_mask(total, begin, end):
+    mask = np.zeros([total]).astype(np.float32)
+    mask[begin:end] = 1
+    return mask
diff --git a/tests/st/gnn/gcn/src/gcn.py b/tests/st/gnn/gcn/src/gcn.py
new file mode 100644
index 0000000000..a364c1bafa
--- /dev/null
+++ b/tests/st/gnn/gcn/src/gcn.py
@@ -0,0 +1,163 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+from mindspore import nn
+from mindspore.common.parameter import ParameterTuple
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+from mindspore import Tensor
+from mindspore.nn.layer.activation import get_activation
+from src.metrics import Loss, Accuracy
+
+
+def glorot(shape):
+    init_range = np.sqrt(6.0 / (shape[0] + shape[1]))
+    initial = np.random.uniform(-init_range, init_range, shape).astype(np.float32)
+    return Tensor(initial)
+
+
+class GraphConvolution(nn.Cell):
+    def __init__(self,
+                 feature_in_dim,
+                 feature_out_dim,
+                 dropout_ratio=None,
+                 activation=None):
+        super(GraphConvolution, self).__init__()
+        self.in_dim = feature_in_dim
+        self.out_dim = feature_out_dim
+        self.weight_init = glorot([self.out_dim, self.in_dim])
+        self.fc = nn.Dense(self.in_dim,
+                           self.out_dim,
+                           weight_init=self.weight_init,
+                           has_bias=False)
+        self.dropout_ratio = dropout_ratio
+        if self.dropout_ratio is not None:
+            self.dropout = nn.Dropout(keep_prob=1 - self.dropout_ratio)
+        self.dropout_flag = self.dropout_ratio is not None
+        self.activation = get_activation(activation)
+        self.activation_flag = self.activation is not None
+        self.matmul = P.MatMul()
+
+    def construct(self, adj, input_feature):
+        dropout = input_feature
+        if self.dropout_flag:
+            dropout = self.dropout(dropout)
+
+        fc = self.fc(dropout)
+        output_feature = self.matmul(adj, fc)
+
+        if self.activation_flag:
+            output_feature = self.activation(output_feature)
+        return output_feature
+
+
+class GCN(nn.Cell):
+    def __init__(self, config, adj, feature, output_dim):
+        super(GCN, self).__init__()
+        self.adj = Tensor(adj)
+        self.feature = Tensor(feature)
+        input_dim = feature.shape[1]
+        self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
+        self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
+
+    def construct(self):
+        output0 = self.layer0(self.adj, self.feature)
+        output1 = self.layer1(self.adj, output0)
+        return output1
+
+
+class LossAccuracyWrapper(nn.Cell):
+    def __init__(self, network, label, mask, weight_decay):
+        super(LossAccuracyWrapper, self).__init__()
+        self.network = network
+        self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
+        self.accuracy = Accuracy(label, mask)
+
+    def construct(self):
+        preds = self.network()
+        loss = self.loss(preds)
+        accuracy = self.accuracy(preds)
+        return loss, accuracy
+
+
+class LossWrapper(nn.Cell):
+    def __init__(self, network, label, mask, weight_decay):
+        super(LossWrapper, self).__init__()
+        self.network = network
+        self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
+
+    def construct(self):
+        preds = self.network()
+        loss = self.loss(preds)
+        return loss
+
+
+class TrainOneStepCell(nn.Cell):
+    r"""
+    Network training package class.
+
+    Wraps the network with an optimizer. The resulting Cell can be trained without inputs.
+    The backward graph is created in the construct function to update the parameters.
+    Different parallel modes are available for training.
+
+    Args:
+        network (Cell): The training network.
+        optimizer (Cell): Optimizer for updating the weights.
+        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
+
+    Outputs:
+        Tensor, a scalar Tensor with shape :math:`()`.
+
+    Examples:
+        >>> net = Net()
+        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
+        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>> loss_net = nn.WithLossCell(net, loss_fn)
+        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
+    """
+
+    def __init__(self, network, optimizer, sens=1.0):
+        super(TrainOneStepCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.add_flags(defer_inline=True)
+        self.weights = ParameterTuple(network.trainable_params())
+        self.optimizer = optimizer
+        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
+        self.sens = sens
+
+    def construct(self):
+        weights = self.weights
+        loss = self.network()
+        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
+        grads = self.grad(self.network, weights)(sens)
+        return F.depend(loss, self.optimizer(grads))
+
+
+class TrainNetWrapper(nn.Cell):
+    def __init__(self, network, label, mask, config):
+        super(TrainNetWrapper, self).__init__(auto_prefix=True)
+        self.network = network
+        loss_net = LossWrapper(network, label, mask, config.weight_decay)
+        optimizer = nn.Adam(loss_net.trainable_params(),
+                            learning_rate=config.learning_rate)
+        self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
+        self.accuracy = Accuracy(label, mask)
+
+    def construct(self):
+        loss = self.loss_train_net()
+        accuracy = self.accuracy(self.network())
+        return loss, accuracy
diff --git a/tests/st/gnn/gcn/src/metrics.py b/tests/st/gnn/gcn/src/metrics.py
new file mode 100644
index 0000000000..18ce1bd186
--- /dev/null
+++ b/tests/st/gnn/gcn/src/metrics.py
@@ -0,0 +1,68 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from mindspore import nn
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.ops import operations as P
+
+
+class Loss(nn.Cell):
+    def __init__(self, label, mask, weight_decay, param):
+        super(Loss, self).__init__()
+        self.label = Tensor(label)
+        self.mask = Tensor(mask)
+        self.loss = P.SoftmaxCrossEntropyWithLogits()
+        self.one = Tensor(1.0, mstype.float32)
+        self.zero = Tensor(0.0, mstype.float32)
+        self.mean = P.ReduceMean()
+        self.cast = P.Cast()
+        self.l2_loss = P.L2Loss()
+        self.reduce_sum = P.ReduceSum()
+        self.weight_decay = weight_decay
+        self.param = param
+
+    def construct(self, preds):
+        param = self.l2_loss(self.param)
+        loss = self.weight_decay * param
+        preds = self.cast(preds, mstype.float32)
+        loss = loss + self.loss(preds, self.label)[0]
+        mask = self.cast(self.mask, mstype.float32)
+        mask_reduce = self.mean(mask)
+        mask = mask / mask_reduce
+        loss = loss * mask
+        loss = self.mean(loss)
+        return loss
+
+
+class Accuracy(nn.Cell):
+    def __init__(self, label, mask):
+        super(Accuracy, self).__init__()
+        self.label = Tensor(label)
+        self.mask = Tensor(mask)
+        self.equal = P.Equal()
+        self.argmax = P.Argmax()
+        self.cast = P.Cast()
+        self.mean = P.ReduceMean()
+
+    def construct(self, preds):
+        preds = self.cast(preds, mstype.float32)
+        correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label))
+        accuracy_all = self.cast(correct_prediction, mstype.float32)
+        mask = self.cast(self.mask, mstype.float32)
+        mask_reduce = self.mean(mask)
+        mask = mask / mask_reduce
+        accuracy_all *= mask
+        return self.mean(accuracy_all)
diff --git a/tests/st/gnn/gcn/test_gcn.py b/tests/st/gnn/gcn/test_gcn.py
new file mode 100644
index 0000000000..21304a7b8b
--- /dev/null
+++ b/tests/st/gnn/gcn/test_gcn.py
@@ -0,0 +1,83 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import time
+import pytest
+import numpy as np
+from mindspore import context
+from src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
+from src.config import ConfigGCN
+from src.dataset import get_adj_features_labels, get_mask
+
+
+DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
+TRAIN_NODE_NUM = 140
+EVAL_NODE_NUM = 500
+TEST_NODE_NUM = 1000
+SEED = 20
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_gcn():
+    print("test_gcn begin")
+    np.random.seed(SEED)
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target="Ascend", save_graphs=True)
+    config = ConfigGCN()
+    adj, feature, label = get_adj_features_labels(DATA_DIR)
+
+    nodes_num = label.shape[0]
+    train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
+    eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
+    test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
+
+    class_num = label.shape[1]
+    gcn_net = GCN(config, adj, feature, class_num)
+    gcn_net.add_flags_recursive(fp16=True)
+
+    eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, config.weight_decay)
+    test_net = LossAccuracyWrapper(gcn_net, label, test_mask, config.weight_decay)
+    train_net = TrainNetWrapper(gcn_net, label, train_mask, config)
+
+    loss_list = []
+    for epoch in range(config.epochs):
+        t = time.time()
+
+        train_result = train_net()
+        train_loss = train_result[0].asnumpy()
+        train_accuracy = train_result[1].asnumpy()
+
+        eval_result = eval_net()
+        eval_loss = eval_result[0].asnumpy()
+        eval_accuracy = eval_result[1].asnumpy()
+
+        loss_list.append(eval_loss)
+        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
+              "train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
+              "val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
+
+        if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping + 1):-1]):
+            print("Early stopping...")
+            break
+
+    test_result = test_net()
+    test_loss = test_result[0].asnumpy()
+    test_accuracy = test_result[1].asnumpy()
+    print("Test set results:", "loss=", "{:.5f}".format(test_loss),
+          "accuracy=", "{:.5f}".format(test_accuracy))
+    assert test_accuracy > 0.812
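
Notes on the scripts above:

1. normalize_adj implements the symmetric normalization from the GCN paper
(Kipf & Welling, 2017): get_adj_features_labels first symmetrizes the adjacency
matrix and adds self-loops, then the propagation matrix is D^-1/2 (A + I) D^-1/2.
A minimal standalone NumPy/SciPy sketch of the same computation, using a
hypothetical 3-node toy graph (not from the patch):

    import numpy as np
    import scipy.sparse as sp

    # Hypothetical directed toy graph with 3 nodes.
    adj = sp.coo_matrix(np.array([[0., 1., 0.],
                                  [0., 0., 1.],
                                  [0., 0., 0.]], dtype=np.float32))
    n = adj.shape[0]

    # Symmetrize and add self-loops, as in get_adj_features_labels.
    adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(n)

    # Symmetric normalization, as in normalize_adj.
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    nor_adj = adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()

    # Each nonzero entry (i, j) equals 1 / sqrt(deg(i) * deg(j)).
    print(nor_adj.todense())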
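
2. Loss and Accuracy in src/metrics.py average over a node subset without any
gather ops: they rescale the 0/1 mask so that its mean is 1, multiply the
per-node values by it, and take a plain mean over all nodes. A short NumPy
sketch with hypothetical values showing why this equals the mean over the
masked nodes only:

    import numpy as np

    values = np.array([0.2, 0.9, 0.4, 0.7], dtype=np.float32)  # hypothetical per-node losses
    mask = np.array([1., 1., 0., 0.], dtype=np.float32)        # only the first two nodes count

    # mask / mean(mask) sums to the total node count, so the plain mean below
    # weights each selected node by 1 / (number of selected nodes).
    masked_mean = np.mean(values * (mask / np.mean(mask)))

    assert np.isclose(masked_mean, values[mask > 0].mean())  # both equal 0.55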
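
3. get_adj_features_labels builds the one-hot label matrix by row-indexing a
rectangular identity matrix: np.eye(nodes_num, class_num)[labels]. Row k of
that matrix has a 1 in column k whenever k < class_num, so indexing rows by
the label ids yields standard one-hot vectors. A tiny sketch with hypothetical
label ids:

    import numpy as np

    labels = np.array([0, 2, 1, 2])  # hypothetical class ids for 4 nodes
    nodes_num, class_num = labels.shape[0], labels.max() + 1

    labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32)

    assert (labels_onehot == np.array([[1., 0., 0.],
                                       [0., 0., 1.],
                                       [0., 1., 0.],
                                       [0., 0., 1.]], dtype=np.float32)).all()

The more common spelling np.eye(class_num)[labels] gives the same result; the
rectangular form simply reuses nodes_num for the row count.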