diff --git a/example/graph_to_mindrecord/README.md b/example/graph_to_mindrecord/README.md index cc6f6a1c70..df7ab33444 100644 --- a/example/graph_to_mindrecord/README.md +++ b/example/graph_to_mindrecord/README.md @@ -24,9 +24,6 @@ This example provides an efficient way to generate MindRecord. Users only need t 1. Download and prepare the Cora dataset as required. - > [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora) - - 2. Edit write_cora.sh and modify the parameters ``` --mindrecord_file: output MindRecord file. diff --git a/example/graph_to_mindrecord/citeseer/mr_api.py b/example/graph_to_mindrecord/citeseer/mr_api.py index 8b1f424b0a..69bc442f4e 100644 --- a/example/graph_to_mindrecord/citeseer/mr_api.py +++ b/example/graph_to_mindrecord/citeseer/mr_api.py @@ -15,29 +15,26 @@ """ User-defined API for MindRecord GNN writer. """ -import csv import os +import pickle as pkl import numpy as np import scipy.sparse as sp # parse args from command line parameter 'graph_api_args' # args delimiter is ':' args = os.environ['graph_api_args'].split(':') -CITESEER_CONTENT_FILE = args[0] -CITESEER_CITES_FILE = args[1] -CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord" -CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord" - -node_id_map = {} +CITESEER_PATH = args[0] +dataset_str = 'citeseer' # profile: (num_features, feature_data_types, feature_shapes) -node_profile = (2, ["float32", "int64"], [[-1], [-1]]) +node_profile = (2, ["float32", "int32"], [[-1], [-1]]) edge_profile = (0, [], []) +node_ids = [] + def _normalize_citeseer_features(features): - features = np.array(features) row_sum = np.array(features.sum(1)) r_inv = np.power(row_sum * 1.0, -1).flatten() r_inv[np.isinf(r_inv)] = 0. @@ -46,6 +43,14 @@ def _normalize_citeseer_features(features): return features +def _parse_index_file(filename): + """Parse index file.""" + index = [] + for line in open(filename): + index.append(int(line.strip())) + return index + + def yield_nodes(task_id=0): """ Generate node data @@ -54,29 +59,46 @@ def yield_nodes(task_id=0): data (dict): data row which is dict. """ print("Node task is {}".format(task_id)) - label_types = {} - label_size = 0 - node_num = 0 - with open(CITESEER_CONTENT_FILE) as content_file: - content_reader = csv.reader(content_file, delimiter='\t') - line_count = 0 - for row in content_reader: - if not row[-1] in label_types: - label_types[row[-1]] = label_size - label_size += 1 - if not row[0] in node_id_map: - node_id_map[row[0]] = node_num - node_num += 1 - raw_features = [[int(x) for x in row[1:-1]]] - node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features), - 'feature_2': [label_types[row[-1]]]} - yield node - line_count += 1 + names = ['x', 'y', 'tx', 'ty', 'allx', 'ally'] + objects = [] + for name in names: + with open("{}/ind.{}.{}".format(CITESEER_PATH, dataset_str, name), 'rb') as f: + objects.append(pkl.load(f, encoding='latin1')) + x, y, tx, ty, allx, ally = tuple(objects) + test_idx_reorder = _parse_index_file( + "{}/ind.{}.test.index".format(CITESEER_PATH, dataset_str)) + test_idx_range = np.sort(test_idx_reorder) + + tx = _normalize_citeseer_features(tx) + allx = _normalize_citeseer_features(allx) + + # Fix citeseer dataset (there are some isolated nodes in the graph) + # Find isolated nodes, add them as zero-vecs into the right position + test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1) + tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) + tx_extended[test_idx_range-min(test_idx_range), :] = tx + tx = tx_extended + ty_extended = np.zeros((len(test_idx_range_full), y.shape[1])) + ty_extended[test_idx_range-min(test_idx_range), :] = ty + ty = ty_extended + + features = sp.vstack((allx, tx)).tolil() + features[test_idx_reorder, :] = features[test_idx_range, :] + features = features.A + + labels = np.vstack((ally, ty)) + labels[test_idx_reorder, :] = labels[test_idx_range, :] + + line_count = 0 + for i, label in enumerate(labels): + if not 1 in label.tolist(): + continue + node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(), + 'feature_2': label.tolist().index(1)} + line_count += 1 + node_ids.append(i) + yield node print('Processed {} lines for nodes.'.format(line_count)) - # print('label types {}.'.format(label_types)) - with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f: - for k in label_types: - print(k + ',' + str(label_types[k]), file=f) def yield_edges(task_id=0): @@ -87,23 +109,20 @@ def yield_edges(task_id=0): data (dict): data row which is dict. """ print("Edge task is {}".format(task_id)) - # print(map_string_int) - with open(CITESEER_CITES_FILE) as cites_file: - cites_reader = csv.reader(cites_file, delimiter='\t') + with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f: + graph = pkl.load(f, encoding='latin1') line_count = 0 - for row in cites_reader: - if not row[0] in node_id_map: - print('Source node {} does not exist.'.format(row[0])) - continue - if not row[1] in node_id_map: - print('Destination node {} does not exist.'.format(row[1])) - continue - line_count += 1 - edge = {'id': line_count, - 'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} - yield edge - - with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f: - for k in node_id_map: - print(k + ',' + str(node_id_map[k]), file=f) + for i in graph: + for dst_id in graph[i]: + if not i in node_ids: + print('Source node {} does not exist.'.format(i)) + continue + if not dst_id in node_ids: + print('Destination node {} does not exist.'.format( + dst_id)) + continue + edge = {'id': line_count, + 'src_id': i, 'dst_id': dst_id, 'type': 0} + line_count += 1 + yield edge print('Processed {} lines for edges.'.format(line_count)) diff --git a/example/graph_to_mindrecord/cora/mr_api.py b/example/graph_to_mindrecord/cora/mr_api.py index 0963fd78f7..aeeb0e04de 100644 --- a/example/graph_to_mindrecord/cora/mr_api.py +++ b/example/graph_to_mindrecord/cora/mr_api.py @@ -15,29 +15,24 @@ """ User-defined API for MindRecord GNN writer. """ -import csv import os +import pickle as pkl import numpy as np import scipy.sparse as sp # parse args from command line parameter 'graph_api_args' # args delimiter is ':' args = os.environ['graph_api_args'].split(':') -CORA_CONTENT_FILE = args[0] -CORA_CITES_FILE = args[1] -CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord" -CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord" - -node_id_map = {} +CORA_PATH = args[0] +dataset_str = 'cora' # profile: (num_features, feature_data_types, feature_shapes) -node_profile = (2, ["float32", "int64"], [[-1], [-1]]) +node_profile = (2, ["float32", "int32"], [[-1], [-1]]) edge_profile = (0, [], []) def _normalize_cora_features(features): - features = np.array(features) row_sum = np.array(features.sum(1)) r_inv = np.power(row_sum * 1.0, -1).flatten() r_inv[np.isinf(r_inv)] = 0. @@ -46,6 +41,14 @@ def _normalize_cora_features(features): return features +def _parse_index_file(filename): + """Parse index file.""" + index = [] + for line in open(filename): + index.append(int(line.strip())) + return index + + def yield_nodes(task_id=0): """ Generate node data @@ -54,32 +57,32 @@ def yield_nodes(task_id=0): data (dict): data row which is dict. """ print("Node task is {}".format(task_id)) - label_types = {} - label_size = 0 - node_num = 0 - with open(CORA_CONTENT_FILE) as content_file: - content_reader = csv.reader(content_file, delimiter=',') - line_count = 0 - for row in content_reader: - if line_count == 0: - line_count += 1 - continue - if not row[0] in node_id_map: - node_id_map[row[0]] = node_num - node_num += 1 - if not row[-1] in label_types: - label_types[row[-1]] = label_size - label_size += 1 - raw_features = [[int(x) for x in row[1:-1]]] - node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features), - 'feature_2': [label_types[row[-1]]]} - yield node - line_count += 1 + + names = ['tx', 'ty', 'allx', 'ally'] + objects = [] + for name in names: + with open("{}/ind.{}.{}".format(CORA_PATH, dataset_str, name), 'rb') as f: + objects.append(pkl.load(f, encoding='latin1')) + tx, ty, allx, ally = tuple(objects) + test_idx_reorder = _parse_index_file( + "{}/ind.{}.test.index".format(CORA_PATH, dataset_str)) + test_idx_range = np.sort(test_idx_reorder) + + features = sp.vstack((allx, tx)).tolil() + features[test_idx_reorder, :] = features[test_idx_range, :] + features = _normalize_cora_features(features) + features = features.A + + labels = np.vstack((ally, ty)) + labels[test_idx_reorder, :] = labels[test_idx_range, :] + + line_count = 0 + for i, label in enumerate(labels): + node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(), + 'feature_2': label.tolist().index(1)} + line_count += 1 + yield node print('Processed {} lines for nodes.'.format(line_count)) - print('label types {}.'.format(label_types)) - with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f: - for k in label_types: - print(k + ',' + str(label_types[k]), file=f) def yield_edges(task_id=0): @@ -90,24 +93,13 @@ def yield_edges(task_id=0): data (dict): data row which is dict. """ print("Edge task is {}".format(task_id)) - with open(CORA_CITES_FILE) as cites_file: - cites_reader = csv.reader(cites_file, delimiter=',') + with open("{}/ind.{}.graph".format(CORA_PATH, dataset_str), 'rb') as f: + graph = pkl.load(f, encoding='latin1') line_count = 0 - for row in cites_reader: - if line_count == 0: + for i in graph: + for dst_id in graph[i]: + edge = {'id': line_count, + 'src_id': i, 'dst_id': dst_id, 'type': 0} line_count += 1 - continue - if not row[0] in node_id_map: - print('Source node {} does not exist.'.format(row[0])) - continue - if not row[1] in node_id_map: - print('Destination node {} does not exist.'.format(row[1])) - continue - edge = {'id': line_count, - 'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} - yield edge - line_count += 1 + yield edge print('Processed {} lines for edges.'.format(line_count)) - with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f: - for k in node_id_map: - print(k + ',' + str(node_id_map[k]), file=f) diff --git a/example/graph_to_mindrecord/write_citeseer.sh b/example/graph_to_mindrecord/write_citeseer.sh index 33235372fa..523b2b8850 100644 --- a/example/graph_to_mindrecord/write_citeseer.sh +++ b/example/graph_to_mindrecord/write_citeseer.sh @@ -9,4 +9,4 @@ python writer.py --mindrecord_script citeseer \ --mindrecord_partitions 1 \ --mindrecord_header_size_by_bit 18 \ --mindrecord_page_size_by_bit 20 \ ---graph_api_args "$SRC_PATH/citeseer.content:$SRC_PATH/citeseer.cites" +--graph_api_args "$SRC_PATH" diff --git a/example/graph_to_mindrecord/write_cora.sh b/example/graph_to_mindrecord/write_cora.sh index 84ccf34f5e..fd1b6fc92a 100644 --- a/example/graph_to_mindrecord/write_cora.sh +++ b/example/graph_to_mindrecord/write_cora.sh @@ -9,4 +9,4 @@ python writer.py --mindrecord_script cora \ --mindrecord_partitions 1 \ --mindrecord_header_size_by_bit 18 \ --mindrecord_page_size_by_bit 20 \ ---graph_api_args "$SRC_PATH/cora_content.csv:$SRC_PATH/cora_cites.csv" +--graph_api_args "$SRC_PATH" diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index b1734eaa2b..308cd03ac4 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -527,10 +527,22 @@ void bindGraphData(py::module *m) { THROW_IF_ERROR(g_out->Init()); return g_out; })) - .def("get_nodes", - [](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) { + .def("get_all_nodes", + [](gnn::Graph &g, gnn::NodeType node_type) { std::shared_ptr out; - THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out)); + THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); + return out; + }) + .def("get_all_edges", + [](gnn::Graph &g, gnn::EdgeType edge_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); + return out; + }) + .def("get_nodes_from_edges", + [](gnn::Graph &g, std::vector edge_list) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); return out; }) .def("get_all_neighbors", @@ -539,12 +551,31 @@ void bindGraphData(py::module *m) { THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); return out; }) + .def("get_sampled_neighbors", + [](gnn::Graph &g, std::vector node_list, std::vector neighbor_nums, + std::vector neighbor_types) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); + return out; + }) + .def("get_neg_sampled_neighbors", + [](gnn::Graph &g, std::vector node_list, gnn::NodeIdType neighbor_num, + gnn::NodeType neg_neighbor_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); + return out; + }) .def("get_node_feature", [](gnn::Graph &g, std::shared_ptr node_list, std::vector feature_types) { TensorRow out; THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); return out; - }); + }) + .def("graph_info", [](gnn::Graph &g) { + py::dict out; + THROW_IF_ERROR(g.GraphInfo(&out)); + return out; + }); } // This is where we externalize the C logic as python modules diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/dataset/engine/gnn/graph.cc index 74e7b85153..2ac3f3f5bd 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.cc @@ -17,29 +17,30 @@ #include #include +#include #include #include #include "dataset/core/tensor_shape.h" +#include "dataset/util/random.h" namespace mindspore { namespace dataset { namespace gnn { -Graph::Graph(std::string dataset_file, int32_t num_workers) : dataset_file_(dataset_file), num_workers_(num_workers) { +Graph::Graph(std::string dataset_file, int32_t num_workers) + : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) { + rnd_.seed(GetSeed()); MS_LOG(INFO) << "num_workers:" << num_workers; } -Status Graph::GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr *out) { +Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr *out) { auto itr = node_type_map_.find(node_type); if (itr == node_type_map_.end()) { std::string err_msg = "Invalid node type:" + std::to_string(node_type); RETURN_STATUS_UNEXPECTED(err_msg); } else { - if (node_num == -1) { - RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); - } else { - } + RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); } return Status::OK(); } @@ -59,9 +60,9 @@ Status Graph::CreateTensorByVector(const std::vector> &data, Data RETURN_IF_NOT_OK(Tensor::CreateTensor( &tensor, TensorImpl::kFlexible, TensorShape({static_cast(m), static_cast(n)}), type, nullptr)); T *ptr = reinterpret_cast(tensor->GetMutableBuffer()); - for (auto id_m : data) { + for (const auto &id_m : data) { CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); - for (auto id_n : id_m) { + for (const auto &id_n : id_m) { *ptr = id_n; ptr++; } @@ -89,7 +90,38 @@ Status Graph::ComplementVector(std::vector> *data, size_t max_siz return Status::OK(); } -Status Graph::GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr *out) { return Status::OK(); } +Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr *out) { + auto itr = edge_type_map_.find(edge_type); + if (itr == edge_type_map_.end()) { + std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); + } + return Status::OK(); +} + +Status Graph::GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) { + if (edge_list.empty()) { + RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); + } + + std::vector> node_list; + node_list.reserve(edge_list.size()); + for (const auto &edge_id : edge_list) { + auto itr = edge_id_map_.find(edge_id); + if (itr == edge_id_map_.end()) { + std::string err_msg = "Invalid edge id:" + std::to_string(edge_id); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + std::pair, std::shared_ptr> nodes; + RETURN_IF_NOT_OK(itr->second->GetNode(&nodes)); + node_list.push_back({nodes.first->id(), nodes.second->id()}); + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(node_list, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, std::shared_ptr *out) { @@ -105,14 +137,10 @@ Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType size_t max_neighbor_num = 0; neighbors.resize(node_list.size()); for (size_t i = 0; i < node_list.size(); ++i) { - auto itr = node_id_map_.find(node_list[i]); - if (itr != node_id_map_.end()) { - RETURN_IF_NOT_OK(itr->second->GetNeighbors(neighbor_type, -1, &neighbors[i])); - max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); - } else { - std::string err_msg = "Invalid node id:" + std::to_string(node_list[i]); - RETURN_STATUS_UNEXPECTED(err_msg); - } + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node)); + RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i])); + max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); } RETURN_IF_NOT_OK(ComplementVector(&neighbors, max_neighbor_num, kDefaultNodeId)); @@ -121,13 +149,94 @@ Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType return Status::OK(); } -Status Graph::GetSampledNeighbor(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) { +Status Graph::GetSampledNeighbors(const std::vector &node_list, + const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), + "The sizes of neighbor_nums and neighbor_types are inconsistent."); + std::vector> neighbors_vec(node_list.size()); + for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { + neighbors_vec[node_idx].emplace_back(node_list[node_idx]); + std::vector input_list = {node_list[node_idx]}; + for (size_t i = 0; i < neighbor_nums.size(); ++i) { + std::vector neighbors; + neighbors.reserve(input_list.size() * neighbor_nums[i]); + for (const auto &node_id : input_list) { + if (node_id == kDefaultNodeId) { + for (int32_t j = 0; j < neighbor_nums[i]; ++j) { + neighbors.emplace_back(kDefaultNodeId); + } + } else { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); + std::vector out; + RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); + neighbors.insert(neighbors.end(), out.begin(), out.end()); + } + } + neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end()); + input_list = std::move(neighbors); + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(neighbors_vec, DataType(DataType::DE_INT32), out)); return Status::OK(); } -Status Graph::GetNegSampledNeighbor(const std::vector &node_list, NodeIdType samples_num, - NodeType neg_neighbor_type, std::shared_ptr *out) { +Status Graph::NegativeSample(const std::vector &data, const std::unordered_set &exclude_data, + int32_t samples_num, std::vector *out_samples) { + CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); + std::vector shuffled_id(data.size()); + std::iota(shuffled_id.begin(), shuffled_id.end(), 0); + std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); + for (const auto &index : shuffled_id) { + if (exclude_data.find(data[index]) != exclude_data.end()) { + continue; + } + out_samples->emplace_back(data[index]); + if (out_samples->size() >= samples_num) { + break; + } + } + return Status::OK(); +} + +Status Graph::GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + std::vector> neighbors_vec; + neighbors_vec.resize(node_list.size()); + for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); + std::vector neighbors; + RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); + std::unordered_set exclude_node; + std::transform(neighbors.begin(), neighbors.end(), + std::insert_iterator>(exclude_node, exclude_node.begin()), + [](const NodeIdType node) { return node; }); + auto itr = node_type_map_.find(neg_neighbor_type); + if (itr == node_type_map_.end()) { + std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + neighbors_vec[node_idx].emplace_back(node->id()); + if (itr->second.size() > exclude_node.size()) { + while (neighbors_vec[node_idx].size() < samples_num + 1) { + RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node, samples_num - neighbors_vec[node_idx].size(), + &neighbors_vec[node_idx])); + } + } else { + MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() + << " neg_neighbor_type:" << neg_neighbor_type; + // If there are no negative neighbors, they are filled with kDefaultNodeId + for (int32_t i = 0; i < samples_num; ++i) { + neighbors_vec[node_idx].emplace_back(kDefaultNodeId); + } + } + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(neighbors_vec, DataType(DataType::DE_INT32), out)); return Status::OK(); } @@ -154,7 +263,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::ve } CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty"); TensorRow tensors; - for (auto f_type : feature_types) { + for (const auto &f_type : feature_types) { std::shared_ptr default_feature; // If no feature can be obtained, fill in the default value RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature)); @@ -169,18 +278,14 @@ Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::ve dsize_t index = 0; for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { - auto itr = node_id_map_.find(*node_itr); std::shared_ptr feature; - if (itr != node_id_map_.end()) { - if (!itr->second->GetFeatures(f_type, &feature).IsOk()) { - feature = default_feature; - } + if (*node_itr == kDefaultNodeId) { + feature = default_feature; } else { - if (*node_itr == kDefaultNodeId) { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); + if (!node->GetFeatures(f_type, &feature).IsOk()) { feature = default_feature; - } else { - std::string err_msg = "Invalid node id:" + std::to_string(*node_itr); - RETURN_STATUS_UNEXPECTED(err_msg); } } RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); @@ -209,35 +314,54 @@ Status Graph::Init() { return Status::OK(); } -Status Graph::GetMetaInfo(std::vector *node_info, std::vector *edge_info) { - node_info->reserve(node_type_map_.size()); - for (auto node : node_type_map_) { - NodeMetaInfo n_info; - n_info.type = node.first; - n_info.num = node.second.size(); - auto itr = node_feature_map_.find(node.first); - if (itr != node_feature_map_.end()) { - for (auto f_type : itr->second) { - n_info.feature_type.push_back(f_type); - } - std::sort(n_info.feature_type.begin(), n_info.feature_type.end()); +Status Graph::GetMetaInfo(MetaInfo *meta_info) { + meta_info->node_type.resize(node_type_map_.size()); + std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), + [](auto itr) { return itr.first; }); + std::sort(meta_info->node_type.begin(), meta_info->node_type.end()); + + meta_info->edge_type.resize(edge_type_map_.size()); + std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(), + [](auto itr) { return itr.first; }); + std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end()); + + for (const auto &node : node_type_map_) { + meta_info->node_num[node.first] = node.second.size(); + } + + for (const auto &edge : edge_type_map_) { + meta_info->edge_num[edge.first] = edge.second.size(); + } + + for (const auto &node_feature : node_feature_map_) { + for (auto type : node_feature.second) { + meta_info->node_feature_type.emplace_back(type); } - node_info->push_back(n_info); - } - - edge_info->reserve(edge_type_map_.size()); - for (auto edge : edge_type_map_) { - EdgeMetaInfo e_info; - e_info.type = edge.first; - e_info.num = edge.second.size(); - auto itr = edge_feature_map_.find(edge.first); - if (itr != edge_feature_map_.end()) { - for (auto f_type : itr->second) { - e_info.feature_type.push_back(f_type); - } + } + std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); + auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); + meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end()); + + for (const auto &edge_feature : edge_feature_map_) { + for (const auto &type : edge_feature.second) { + meta_info->edge_feature_type.emplace_back(type); } - edge_info->push_back(e_info); } + std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); + auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); + meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end()); + return Status::OK(); +} + +Status Graph::GraphInfo(py::dict *out) { + MetaInfo meta_info; + RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); + (*out)["node_type"] = py::cast(meta_info.node_type); + (*out)["edge_type"] = py::cast(meta_info.edge_type); + (*out)["node_num"] = py::cast(meta_info.node_num); + (*out)["edge_num"] = py::cast(meta_info.edge_num); + (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type); + (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type); return Status::OK(); } @@ -250,6 +374,18 @@ Status Graph::LoadNodeAndEdge() { &node_feature_map_, &edge_feature_map_, &default_feature_map_)); return Status::OK(); } + +Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { + auto itr = node_id_map_.find(id); + if (itr == node_id_map_.end()) { + std::string err_msg = "Invalid node id:" + std::to_string(id); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *node = itr->second; + } + return Status::OK(); +} + } // namespace gnn } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.h b/mindspore/ccsrc/dataset/engine/gnn/graph.h index 3dd6444807..694d4eea01 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.h +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -33,24 +34,13 @@ namespace mindspore { namespace dataset { namespace gnn { -struct NodeMetaInfo { - NodeType type; - NodeIdType num; - std::vector feature_type; - NodeMetaInfo() { - type = 0; - num = 0; - } -}; - -struct EdgeMetaInfo { - EdgeType type; - EdgeIdType num; - std::vector feature_type; - EdgeMetaInfo() { - type = 0; - num = 0; - } +struct MetaInfo { + std::vector node_type; + std::vector edge_type; + std::map node_num; + std::map edge_num; + std::vector node_feature_type; + std::vector edge_feature_type; }; class Graph { @@ -62,19 +52,23 @@ class Graph { ~Graph() = default; - // Get the nodes from the graph. + // Get all nodes from the graph. // @param NodeType node_type - type of node - // @param NodeIdType node_num - Number of nodes to be acquired, if -1 means all nodes are acquired // @param std::shared_ptr *out - Returned nodes id // @return Status - The error code return - Status GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr *out); + Status GetAllNodes(NodeType node_type, std::shared_ptr *out); - // Get the edges from the graph. + // Get all edges from the graph. // @param NodeType edge_type - type of edge - // @param NodeIdType edge_num - Number of edges to be acquired, if -1 means all edges are acquired // @param std::shared_ptr *out - Returned edge ids // @return Status - The error code return - Status GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr *out); + Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out); + + // Get the node id from the edge. + // @param std::vector edge_list - List of edges + // @param std::shared_ptr *out - Returned node ids + // @return Status - The error code return + Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out); // All neighbors of the acquisition node. // @param std::vector node_list - List of nodes @@ -86,10 +80,24 @@ class Graph { Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, std::shared_ptr *out); - Status GetSampledNeighbor(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out); - Status GetNegSampledNeighbor(const std::vector &node_list, NodeIdType samples_num, - NodeType neg_neighbor_type, std::shared_ptr *out); + // Get sampled neighbors. + // @param std::vector node_list - List of nodes + // @param std::vector neighbor_nums - Number of neighbors sampled per hop + // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param std::shared_ptr *out - Returned neighbor's id. + // @return Status - The error code return + Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out); + + // Get negative sampled neighbors. + // @param std::vector node_list - List of nodes + // @param NodeIdType samples_num - Number of neighbors sampled + // @param NodeType neg_neighbor_type - The type of negative neighbor. + // @param std::shared_ptr *out - Returned negative neighbor's id. + // @return Status - The error code return + Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out); + Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, float p, float q, NodeIdType default_node, std::shared_ptr *out); @@ -112,10 +120,12 @@ class Graph { TensorRow *out); // Get meta information of graph - // @param std::vector *node_info - Returned meta information of node - // @param std::vector *node_info - Returned meta information of edge + // @param MetaInfo *meta_info - Returned meta information // @return Status - The error code return - Status GetMetaInfo(std::vector *node_info, std::vector *edge_info); + Status GetMetaInfo(MetaInfo *meta_info); + + // Return meta information to python layer + Status GraphInfo(py::dict *out); Status Init(); @@ -146,8 +156,24 @@ class Graph { // @return Status - The error code return Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + // Find node object using node id + // @param NodeIdType id - + // @param std::shared_ptr *node - Returned node object + // @return Status - The error code return + Status GetNodeByNodeId(NodeIdType id, std::shared_ptr *node); + + // Negative sampling + // @param std::vector &input_data - The data set to be sampled + // @param std::unordered_set &exclude_data - Data to be excluded + // @param int32_t samples_num - + // @param std::vector *out_samples - Sampling results returned + // @return Status - The error code return + Status NegativeSample(const std::vector &input_data, const std::unordered_set &exclude_data, + int32_t samples_num, std::vector *out_samples); + std::string dataset_file_; int32_t num_workers_; // The number of worker threads + std::mt19937 rnd_; std::unordered_map> node_type_map_; std::unordered_map> node_id_map_; diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc index 24e865dff7..e091a52faa 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc @@ -20,12 +20,13 @@ #include #include "dataset/engine/gnn/edge.h" +#include "dataset/util/random.h" namespace mindspore { namespace dataset { namespace gnn { -LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type) {} +LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { auto itr = features_.find(feature_type); @@ -38,21 +39,49 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr } } -Status LocalNode::GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector *out_neighbors) { +Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors) { std::vector neighbors; auto itr = neighbor_nodes_.find(neighbor_type); if (itr != neighbor_nodes_.end()) { - if (samples_num == -1) { - // Return all neighbors - neighbors.resize(itr->second.size() + 1); - neighbors[0] = id_; - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, - [](const std::shared_ptr node) { return node->id(); }); - } else { - } + neighbors.resize(itr->second.size() + 1); + neighbors[0] = id_; + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, + [](const std::shared_ptr node) { return node->id(); }); } else { - neighbors.push_back(id_); MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; + neighbors.emplace_back(id_); + } + *out_neighbors = std::move(neighbors); + return Status::OK(); +} + +Status LocalNode::GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out) { + std::vector shuffled_id(neighbors.size()); + std::iota(shuffled_id.begin(), shuffled_id.end(), 0); + std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); + int32_t num = std::min(samples_num, static_cast(neighbors.size())); + for (int32_t i = 0; i < num; ++i) { + out->emplace_back(neighbors[shuffled_id[i]]->id()); + } + return Status::OK(); +} + +Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + std::vector *out_neighbors) { + std::vector neighbors; + neighbors.reserve(samples_num); + auto itr = neighbor_nodes_.find(neighbor_type); + if (itr != neighbor_nodes_.end()) { + while (neighbors.size() < samples_num) { + RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); + } + } else { + MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; + // If there are no neighbors, they are filled with kDefaultNodeId + for (int32_t i = 0; i < samples_num; ++i) { + neighbors.emplace_back(kDefaultNodeId); + } } *out_neighbors = std::move(neighbors); return Status::OK(); diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/dataset/engine/gnn/local_node.h index 25f24818e1..b9b007c420 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.h +++ b/mindspore/ccsrc/dataset/engine/gnn/local_node.h @@ -43,12 +43,19 @@ class LocalNode : public Node { // @return Status - The error code return Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; - // Get the neighbors of a node + // Get the all neighbors of a node // @param NodeType neighbor_type - type of neighbor - // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired // @param std::vector *out_neighbors - Returned neighbors id // @return Status - The error code return - Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector *out_neighbors) override; + Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors) override; + + // Get the sampled neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param int32_t samples_num - Number of neighbors to be acquired + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + std::vector *out_neighbors) override; // Add neighbor of node // @param std::shared_ptr node - @@ -61,6 +68,10 @@ class LocalNode : public Node { Status UpdateFeature(const std::shared_ptr &feature) override; private: + Status GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out); + + std::mt19937 rnd_; std::unordered_map> features_; std::unordered_map>> neighbor_nodes_; }; diff --git a/mindspore/ccsrc/dataset/engine/gnn/node.h b/mindspore/ccsrc/dataset/engine/gnn/node.h index 8e3db51d65..f0136e92d7 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/node.h +++ b/mindspore/ccsrc/dataset/engine/gnn/node.h @@ -52,12 +52,19 @@ class Node { // @return Status - The error code return virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; - // Get the neighbors of a node + // Get the all neighbors of a node // @param NodeType neighbor_type - type of neighbor - // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired // @param std::vector *out_neighbors - Returned neighbors id // @return Status - The error code return - virtual Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector *out_neighbors) = 0; + virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors) = 0; + + // Get the sampled neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param int32_t samples_num - Number of neighbors to be acquired + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + std::vector *out_neighbors) = 0; // Add neighbor of node // @param std::shared_ptr node - diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 23f8dbda6a..573aa84e4b 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -20,8 +20,9 @@ import numpy as np from mindspore._c_dataengine import Graph from mindspore._c_dataengine import Tensor -from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \ - check_gnn_get_node_feature +from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ + check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ + check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature class GraphData: @@ -60,7 +61,44 @@ class GraphData: Raises: TypeError: If `node_type` is not integer. """ - return self._graph.get_nodes(node_type, -1).as_array() + return self._graph.get_all_nodes(node_type).as_array() + + @check_gnn_get_all_edges + def get_all_edges(self, edge_type): + """ + Get all edges in the graph. + + Args: + edge_type (int): Specify the type of edge. + + Returns: + numpy.ndarray: array of edges. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> nodes = data_graph.get_all_edges(0) + + Raises: + TypeError: If `edge_type` is not integer. + """ + return self._graph.get_all_edges(edge_type).as_array() + + @check_gnn_get_nodes_from_edges + def get_nodes_from_edges(self, edge_list): + """ + Get nodes from the edges. + + Args: + edge_list (list or numpy.ndarray): The given list of edges. + + Returns: + numpy.ndarray: array of nodes. + + Raises: + TypeError: If `edge_list` is not list or ndarray. + """ + return self._graph.get_nodes_from_edges(edge_list).as_array() @check_gnn_get_all_neighbors def get_all_neighbors(self, node_list, neighbor_type): @@ -86,6 +124,58 @@ class GraphData: """ return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() + @check_gnn_get_sampled_neighbors + def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): + """ + Get sampled neighbor information, maximum support 6-hop sampling. + + Args: + node_list (list or numpy.ndarray): The given list of nodes. + neighbor_nums (list or numpy.ndarray): Number of neighbors sampled per hop. + neighbor_types (list or numpy.ndarray): Neighbor type sampled per hop. + + Returns: + numpy.ndarray: array of nodes. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> nodes = data_graph.get_all_nodes(0) + >>> neighbors = data_graph.get_all_neighbors(nodes, [2, 2], [0, 0]) + + Raises: + TypeError: If `node_list` is not list or ndarray. + TypeError: If `neighbor_nums` is not list or ndarray. + TypeError: If `neighbor_types` is not list or ndarray. + """ + return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array() + + @check_gnn_get_neg_sampled_neighbors + def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): + """ + Get `neg_neighbor_type` negative sampled neighbors of the nodes in `node_list`. + + Args: + node_list (list or numpy.ndarray): The given list of nodes. + neg_neighbor_num (int): Number of neighbors sampled. + neg_neighbor_type (int): Specify the type of negative neighbor. + + Returns: + numpy.ndarray: array of nodes. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> nodes = data_graph.get_all_nodes(0) + >>> neg_neighbors = data_graph.get_neg_sampled_neighbors(nodes, 5, 0) + + Raises: + TypeError: If `node_list` is not list or ndarray. + TypeError: If `neg_neighbor_num` is not integer. + TypeError: If `neg_neighbor_type` is not integer. + """ + return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array() + @check_gnn_get_node_feature def get_node_feature(self, node_list, feature_types): """ @@ -111,3 +201,13 @@ class GraphData: if isinstance(node_list, list): node_list = np.array(node_list, dtype=np.int32) return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)] + + def graph_info(self): + """ + Get the meta information of the graph, including the number of nodes, the type of nodes, + the feature information of nodes, the number of edges, the type of edges, and the feature information of edges. + Returns: + Dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num, + node_feature_type and edge_feature_type. + """ + return self._graph.graph_info() diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 94f2d0b8d5..eee4dde2bd 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1153,6 +1153,36 @@ def check_gnn_get_all_nodes(method): return new_method +def check_gnn_get_all_edges(method): + """A wrapper that wrap a parameter checker to the GNN `get_all_edges` function.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check node_type; required argument + check_type(param_dict.get("edge_type"), 'edge_type', int) + + return method(*args, **kwargs) + + return new_method + + +def check_gnn_get_nodes_from_edges(method): + """A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check edge_list; required argument + check_gnn_list_or_ndarray(param_dict.get("edge_list"), 'edge_list') + + return method(*args, **kwargs) + + return new_method + + def check_gnn_get_all_neighbors(method): """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" @@ -1171,6 +1201,61 @@ def check_gnn_get_all_neighbors(method): return new_method +def check_gnn_get_sampled_neighbors(method): + """A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check node_list; required argument + check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') + + # check neighbor_nums; required argument + neighbor_nums = param_dict.get("neighbor_nums") + check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums') + if len(neighbor_nums) > 6: + raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( + 'neighbor_nums', len(neighbor_nums))) + + # check neighbor_types; required argument + neighbor_types = param_dict.get("neighbor_types") + check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types') + if len(neighbor_nums) > 6: + raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( + 'neighbor_types', len(neighbor_types))) + + if len(neighbor_nums) != len(neighbor_types): + raise ValueError( + "The number of members of neighbor_nums and neighbor_types is inconsistent") + + return method(*args, **kwargs) + + return new_method + + +def check_gnn_get_neg_sampled_neighbors(method): + """A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check node_list; required argument + check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') + + # check neg_neighbor_num; required argument + check_type(param_dict.get("neg_neighbor_num"), 'neg_neighbor_num', int) + + # check neg_neighbor_type; required argument + check_type(param_dict.get("neg_neighbor_type"), + 'neg_neighbor_type', int) + + return method(*args, **kwargs) + + return new_method + + def check_aligned_list(param, param_name, membor_type): """Check whether the structure of each member of the list is the same.""" diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index 0aefffe784..7c644a3ae7 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -13,8 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include +#include #include "common/common.h" #include "gtest/gtest.h" @@ -45,7 +47,7 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) { &default_feature_map) .IsOk()); EXPECT_EQ(n_id_map.size(), 20); - EXPECT_EQ(e_id_map.size(), 20); + EXPECT_EQ(e_id_map.size(), 40); EXPECT_EQ(n_type_map[2].size(), 10); EXPECT_EQ(n_type_map[1].size(), 10); } @@ -56,14 +58,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { Status s = graph.Init(); EXPECT_TRUE(s.IsOk()); - std::vector node_info; - std::vector edge_info; - s = graph.GetMetaInfo(&node_info, &edge_info); + MetaInfo meta_info; + s = graph.GetMetaInfo(&meta_info); EXPECT_TRUE(s.IsOk()); - EXPECT_TRUE(node_info.size() == 2); + EXPECT_TRUE(meta_info.node_type.size() == 2); std::shared_ptr nodes; - s = graph.GetNodes(node_info[1].type, -1, &nodes); + s = graph.GetAllNodes(meta_info.node_type[0], &nodes); EXPECT_TRUE(s.IsOk()); std::vector node_list; for (auto itr = nodes->begin(); itr != nodes->end(); ++itr) { @@ -73,13 +74,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { } } std::shared_ptr neighbors; - s = graph.GetAllNeighbors(node_list, node_info[0].type, &neighbors); + s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], &neighbors); EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>"); TensorRow features; - s = graph.GetNodeFeature(nodes, node_info[1].feature_type, &features); + s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features); EXPECT_TRUE(s.IsOk()); - EXPECT_TRUE(features.size() == 3); + EXPECT_TRUE(features.size() == 4); EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>"); EXPECT_TRUE(features[0]->ToString() == "Tensor (shape: <10,5>, Type: int32)\n" @@ -91,3 +92,106 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { EXPECT_TRUE(features[2]->shape().ToString() == "<10>"); EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]"); } + +TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { + std::string path = "data/mindrecord/testGraphData/testdata"; + Graph graph(path, 1); + Status s = graph.Init(); + EXPECT_TRUE(s.IsOk()); + + MetaInfo meta_info; + s = graph.GetMetaInfo(&meta_info); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(meta_info.node_type.size() == 2); + + std::shared_ptr edges; + s = graph.GetAllEdges(meta_info.edge_type[0], &edges); + EXPECT_TRUE(s.IsOk()); + std::vector edge_list; + edge_list.resize(edges->Size()); + std::transform(edges->begin(), edges->end(), edge_list.begin(), + [](const EdgeIdType edge) { return edge; }); + + std::shared_ptr nodes; + s = graph.GetNodesFromEdges(edge_list, &nodes); + EXPECT_TRUE(s.IsOk()); + std::unordered_set node_set; + std::vector node_list; + int index = 0; + for (auto itr = nodes->begin(); itr != nodes->end(); ++itr) { + index++; + if (index % 2 == 0) { + continue; + } + node_set.emplace(*itr); + if (node_set.size() >= 5) { + break; + } + } + node_list.resize(node_set.size()); + std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; }); + + std::shared_ptr neighbors; + s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); + + neighbors.reset(); + s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>"); + + neighbors.reset(); + s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, + {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>"); + + neighbors.reset(); + s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors); + EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); + + neighbors.reset(); + s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); + EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != + std::string::npos); + + neighbors.reset(); + s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors); + EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos); +} + +TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { + std::string path = "data/mindrecord/testGraphData/testdata"; + Graph graph(path, 1); + Status s = graph.Init(); + EXPECT_TRUE(s.IsOk()); + + MetaInfo meta_info; + s = graph.GetMetaInfo(&meta_info); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(meta_info.node_type.size() == 2); + + std::shared_ptr nodes; + s = graph.GetAllNodes(meta_info.node_type[0], &nodes); + EXPECT_TRUE(s.IsOk()); + std::vector node_list; + for (auto itr = nodes->begin(); itr != nodes->end(); ++itr) { + node_list.push_back(*itr); + if (node_list.size() >= 10) { + break; + } + } + std::shared_ptr neg_neighbors; + s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neg_neighbors->shape().ToString() == "<10,4>"); + + neg_neighbors.reset(); + s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors); + EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); + + neg_neighbors.reset(); + s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors); + EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos); +} diff --git a/tests/ut/data/mindrecord/testGraphData/testdata b/tests/ut/data/mindrecord/testGraphData/testdata index 8978131ee1..e206469ac6 100644 Binary files a/tests/ut/data/mindrecord/testGraphData/testdata and b/tests/ut/data/mindrecord/testGraphData/testdata differ diff --git a/tests/ut/data/mindrecord/testGraphData/testdata.db b/tests/ut/data/mindrecord/testGraphData/testdata.db index f846a67009..541da0e998 100644 Binary files a/tests/ut/data/mindrecord/testGraphData/testdata.db and b/tests/ut/data/mindrecord/testGraphData/testdata.db differ diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 4aa4fc89ee..9b4ff66ac1 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import random import pytest import numpy as np import mindspore.dataset as ds @@ -77,8 +78,110 @@ def test_graphdata_getnodefeature_input_check(): g.get_node_feature(input_list, [1, "a"]) +def test_graphdata_getsampledneighbors(): + g = ds.GraphData(DATASET_FILE, 1) + edges = g.get_all_edges(0) + nodes = g.get_nodes_from_edges(edges) + assert len(nodes) == 40 + neighbor = g.get_sampled_neighbors( + np.unique(nodes[0:21, 0]), [2, 3], [2, 1]) + assert neighbor.shape == (10, 9) + + +def test_graphdata_getnegsampledneighbors(): + g = ds.GraphData(DATASET_FILE, 2) + nodes = g.get_all_nodes(1) + assert len(nodes) == 10 + neighbor = g.get_neg_sampled_neighbors(nodes, 5, 2) + assert neighbor.shape == (10, 6) + + +def test_graphdata_graphinfo(): + g = ds.GraphData(DATASET_FILE, 2) + graph_info = g.graph_info() + assert graph_info['node_type'] == [1, 2] + assert graph_info['edge_type'] == [0] + assert graph_info['node_num'] == {1: 10, 2: 10} + assert graph_info['edge_num'] == {0: 40} + assert graph_info['node_feature_type'] == [1, 2, 3, 4] + assert graph_info['edge_feature_type'] == [] + + +class RandomBatchedSampler(ds.Sampler): + # RandomBatchedSampler generate random sequence without replacement in a batched manner + def __init__(self, index_range, num_edges_per_sample): + super().__init__() + self.index_range = index_range + self.num_edges_per_sample = num_edges_per_sample + + def __iter__(self): + indices = [i+1 for i in range(self.index_range)] + # Reset random seed here if necessary + # random.seed(0) + random.shuffle(indices) + for i in range(0, self.index_range, self.num_edges_per_sample): + # Drop reminder + if i + self.num_edges_per_sample <= self.index_range: + yield indices[i: i + self.num_edges_per_sample] + + +class GNNGraphDataset(): + def __init__(self, g, batch_num): + self.g = g + self.batch_num = batch_num + + def __len__(self): + # Total sample size of GNN dataset + # In this case, the size should be total_num_edges/num_edges_per_sample + return self.g.graph_info()['edge_num'][0] // self.batch_num + + def __getitem__(self, index): + # index will be a list of indices yielded from RandomBatchedSampler + # Fetch edges/nodes/samples/features based on indices + nodes = self.g.get_nodes_from_edges(index.astype(np.int32)) + nodes = nodes[:, 0] + neg_nodes = self.g.get_neg_sampled_neighbors( + node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1) + nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[ + 2, 2], neighbor_types=[2, 1]) + neg_nodes_neighbors = self.g.get_sampled_neighbors( + node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2]) + nodes_neighbors_features = self.g.get_node_feature( + node_list=nodes_neighbors, feature_types=[2, 3]) + neg_neighbors_features = self.g.get_node_feature( + node_list=neg_nodes_neighbors, feature_types=[2, 3]) + return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1] + + +def test_graphdata_generatordataset(): + g = ds.GraphData(DATASET_FILE) + batch_num = 2 + edge_num = g.graph_info()['edge_num'][0] + out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"] + dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names, + sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4) + dataset = dataset.repeat(2) + itr = dataset.create_dict_iterator() + i = 0 + for data in itr: + assert data['neighbors'].shape == (2, 7) + assert data['neg_neighbors'].shape == (6, 7) + assert data['neighbors_features'].shape == (2, 7) + assert data['neg_neighbors_features'].shape == (6, 7) + i += 1 + assert i == 40 + + if __name__ == '__main__': test_graphdata_getfullneighbor() logger.info('test_graphdata_getfullneighbor Ended.\n') test_graphdata_getnodefeature_input_check() logger.info('test_graphdata_getnodefeature_input_check Ended.\n') + test_graphdata_getsampledneighbors() + logger.info('test_graphdata_getsampledneighbors Ended.\n') + test_graphdata_getnegsampledneighbors() + logger.info('test_graphdata_getnegsampledneighbors Ended.\n') + test_graphdata_graphinfo() + logger.info('test_graphdata_graphinfo Ended.\n') + test_graphdata_generatordataset() + logger.info('test_graphdata_generatordataset Ended.\n')