1. Support the get_all_edges, get_nodes_from_edges, get_sampled_neighbors, get_neg_sampled_neighbors and graph_info APIs

2. Modify the Cora and Citeseer conversion scripts
pull/1866/head
heleiwang 5 years ago
parent 444d9484d7
commit 3ece8dd090

@ -24,9 +24,6 @@ This example provides an efficient way to generate MindRecord. Users only need t
1. Download and prepare the Cora dataset as required. 1. Download and prepare the Cora dataset as required.
> [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora)
2. Edit write_cora.sh and modify the parameters 2. Edit write_cora.sh and modify the parameters
``` ```
--mindrecord_file: output MindRecord file. --mindrecord_file: output MindRecord file.

@ -15,29 +15,26 @@
""" """
User-defined API for MindRecord GNN writer. User-defined API for MindRecord GNN writer.
""" """
import csv
import os import os
import pickle as pkl
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
# parse args from command line parameter 'graph_api_args' # parse args from command line parameter 'graph_api_args'
# args delimiter is ':' # args delimiter is ':'
args = os.environ['graph_api_args'].split(':') args = os.environ['graph_api_args'].split(':')
CITESEER_CONTENT_FILE = args[0] CITESEER_PATH = args[0]
CITESEER_CITES_FILE = args[1] dataset_str = 'citeseer'
CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord"
CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord"
node_id_map = {}
# profile: (num_features, feature_data_types, feature_shapes) # profile: (num_features, feature_data_types, feature_shapes)
node_profile = (2, ["float32", "int64"], [[-1], [-1]]) node_profile = (2, ["float32", "int32"], [[-1], [-1]])
edge_profile = (0, [], []) edge_profile = (0, [], [])
node_ids = []
def _normalize_citeseer_features(features): def _normalize_citeseer_features(features):
features = np.array(features)
row_sum = np.array(features.sum(1)) row_sum = np.array(features.sum(1))
r_inv = np.power(row_sum * 1.0, -1).flatten() r_inv = np.power(row_sum * 1.0, -1).flatten()
r_inv[np.isinf(r_inv)] = 0. r_inv[np.isinf(r_inv)] = 0.
@ -46,6 +43,14 @@ def _normalize_citeseer_features(features):
return features return features
def _parse_index_file(filename):
"""Parse index file."""
index = []
for line in open(filename):
index.append(int(line.strip()))
return index
def yield_nodes(task_id=0): def yield_nodes(task_id=0):
""" """
Generate node data Generate node data
@ -54,29 +59,46 @@ def yield_nodes(task_id=0):
data (dict): data row which is dict. data (dict): data row which is dict.
""" """
print("Node task is {}".format(task_id)) print("Node task is {}".format(task_id))
label_types = {} names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
label_size = 0 objects = []
node_num = 0 for name in names:
with open(CITESEER_CONTENT_FILE) as content_file: with open("{}/ind.{}.{}".format(CITESEER_PATH, dataset_str, name), 'rb') as f:
content_reader = csv.reader(content_file, delimiter='\t') objects.append(pkl.load(f, encoding='latin1'))
x, y, tx, ty, allx, ally = tuple(objects)
test_idx_reorder = _parse_index_file(
"{}/ind.{}.test.index".format(CITESEER_PATH, dataset_str))
test_idx_range = np.sort(test_idx_reorder)
tx = _normalize_citeseer_features(tx)
allx = _normalize_citeseer_features(allx)
# Fix citeseer dataset (there are some isolated nodes in the graph)
# Find isolated nodes, add them as zero-vecs into the right position
test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
tx_extended[test_idx_range-min(test_idx_range), :] = tx
tx = tx_extended
ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
ty_extended[test_idx_range-min(test_idx_range), :] = ty
ty = ty_extended
features = sp.vstack((allx, tx)).tolil()
features[test_idx_reorder, :] = features[test_idx_range, :]
features = features.A
labels = np.vstack((ally, ty))
labels[test_idx_reorder, :] = labels[test_idx_range, :]
line_count = 0 line_count = 0
for row in content_reader: for i, label in enumerate(labels):
if not row[-1] in label_types: if not 1 in label.tolist():
label_types[row[-1]] = label_size continue
label_size += 1 node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
if not row[0] in node_id_map: 'feature_2': label.tolist().index(1)}
node_id_map[row[0]] = node_num
node_num += 1
raw_features = [[int(x) for x in row[1:-1]]]
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features),
'feature_2': [label_types[row[-1]]]}
yield node
line_count += 1 line_count += 1
node_ids.append(i)
yield node
print('Processed {} lines for nodes.'.format(line_count)) print('Processed {} lines for nodes.'.format(line_count))
# print('label types {}.'.format(label_types))
with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f:
for k in label_types:
print(k + ',' + str(label_types[k]), file=f)
def yield_edges(task_id=0): def yield_edges(task_id=0):
@ -87,23 +109,20 @@ def yield_edges(task_id=0):
data (dict): data row which is dict. data (dict): data row which is dict.
""" """
print("Edge task is {}".format(task_id)) print("Edge task is {}".format(task_id))
# print(map_string_int) with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
with open(CITESEER_CITES_FILE) as cites_file: graph = pkl.load(f, encoding='latin1')
cites_reader = csv.reader(cites_file, delimiter='\t')
line_count = 0 line_count = 0
for row in cites_reader: for i in graph:
if not row[0] in node_id_map: for dst_id in graph[i]:
print('Source node {} does not exist.'.format(row[0])) if not i in node_ids:
print('Source node {} does not exist.'.format(i))
continue continue
if not row[1] in node_id_map: if not dst_id in node_ids:
print('Destination node {} does not exist.'.format(row[1])) print('Destination node {} does not exist.'.format(
dst_id))
continue continue
line_count += 1
edge = {'id': line_count, edge = {'id': line_count,
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} 'src_id': i, 'dst_id': dst_id, 'type': 0}
line_count += 1
yield edge yield edge
with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f:
for k in node_id_map:
print(k + ',' + str(node_id_map[k]), file=f)
print('Processed {} lines for edges.'.format(line_count)) print('Processed {} lines for edges.'.format(line_count))

@ -15,29 +15,24 @@
""" """
User-defined API for MindRecord GNN writer. User-defined API for MindRecord GNN writer.
""" """
import csv
import os import os
import pickle as pkl
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
# parse args from command line parameter 'graph_api_args' # parse args from command line parameter 'graph_api_args'
# args delimiter is ':' # args delimiter is ':'
args = os.environ['graph_api_args'].split(':') args = os.environ['graph_api_args'].split(':')
CORA_CONTENT_FILE = args[0] CORA_PATH = args[0]
CORA_CITES_FILE = args[1] dataset_str = 'cora'
CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord"
CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord"
node_id_map = {}
# profile: (num_features, feature_data_types, feature_shapes) # profile: (num_features, feature_data_types, feature_shapes)
node_profile = (2, ["float32", "int64"], [[-1], [-1]]) node_profile = (2, ["float32", "int32"], [[-1], [-1]])
edge_profile = (0, [], []) edge_profile = (0, [], [])
def _normalize_cora_features(features): def _normalize_cora_features(features):
features = np.array(features)
row_sum = np.array(features.sum(1)) row_sum = np.array(features.sum(1))
r_inv = np.power(row_sum * 1.0, -1).flatten() r_inv = np.power(row_sum * 1.0, -1).flatten()
r_inv[np.isinf(r_inv)] = 0. r_inv[np.isinf(r_inv)] = 0.
@ -46,6 +41,14 @@ def _normalize_cora_features(features):
return features return features
def _parse_index_file(filename):
"""Parse index file."""
index = []
for line in open(filename):
index.append(int(line.strip()))
return index
def yield_nodes(task_id=0): def yield_nodes(task_id=0):
""" """
Generate node data Generate node data
@ -54,32 +57,32 @@ def yield_nodes(task_id=0):
data (dict): data row which is dict. data (dict): data row which is dict.
""" """
print("Node task is {}".format(task_id)) print("Node task is {}".format(task_id))
label_types = {}
label_size = 0 names = ['tx', 'ty', 'allx', 'ally']
node_num = 0 objects = []
with open(CORA_CONTENT_FILE) as content_file: for name in names:
content_reader = csv.reader(content_file, delimiter=',') with open("{}/ind.{}.{}".format(CORA_PATH, dataset_str, name), 'rb') as f:
objects.append(pkl.load(f, encoding='latin1'))
tx, ty, allx, ally = tuple(objects)
test_idx_reorder = _parse_index_file(
"{}/ind.{}.test.index".format(CORA_PATH, dataset_str))
test_idx_range = np.sort(test_idx_reorder)
features = sp.vstack((allx, tx)).tolil()
features[test_idx_reorder, :] = features[test_idx_range, :]
features = _normalize_cora_features(features)
features = features.A
labels = np.vstack((ally, ty))
labels[test_idx_reorder, :] = labels[test_idx_range, :]
line_count = 0 line_count = 0
for row in content_reader: for i, label in enumerate(labels):
if line_count == 0: node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
'feature_2': label.tolist().index(1)}
line_count += 1 line_count += 1
continue
if not row[0] in node_id_map:
node_id_map[row[0]] = node_num
node_num += 1
if not row[-1] in label_types:
label_types[row[-1]] = label_size
label_size += 1
raw_features = [[int(x) for x in row[1:-1]]]
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features),
'feature_2': [label_types[row[-1]]]}
yield node yield node
line_count += 1
print('Processed {} lines for nodes.'.format(line_count)) print('Processed {} lines for nodes.'.format(line_count))
print('label types {}.'.format(label_types))
with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f:
for k in label_types:
print(k + ',' + str(label_types[k]), file=f)
def yield_edges(task_id=0): def yield_edges(task_id=0):
@ -90,24 +93,13 @@ def yield_edges(task_id=0):
data (dict): data row which is dict. data (dict): data row which is dict.
""" """
print("Edge task is {}".format(task_id)) print("Edge task is {}".format(task_id))
with open(CORA_CITES_FILE) as cites_file: with open("{}/ind.{}.graph".format(CORA_PATH, dataset_str), 'rb') as f:
cites_reader = csv.reader(cites_file, delimiter=',') graph = pkl.load(f, encoding='latin1')
line_count = 0 line_count = 0
for row in cites_reader: for i in graph:
if line_count == 0: for dst_id in graph[i]:
line_count += 1
continue
if not row[0] in node_id_map:
print('Source node {} does not exist.'.format(row[0]))
continue
if not row[1] in node_id_map:
print('Destination node {} does not exist.'.format(row[1]))
continue
edge = {'id': line_count, edge = {'id': line_count,
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} 'src_id': i, 'dst_id': dst_id, 'type': 0}
yield edge
line_count += 1 line_count += 1
yield edge
print('Processed {} lines for edges.'.format(line_count)) print('Processed {} lines for edges.'.format(line_count))
with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f:
for k in node_id_map:
print(k + ',' + str(node_id_map[k]), file=f)

@ -9,4 +9,4 @@ python writer.py --mindrecord_script citeseer \
--mindrecord_partitions 1 \ --mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 18 \ --mindrecord_header_size_by_bit 18 \
--mindrecord_page_size_by_bit 20 \ --mindrecord_page_size_by_bit 20 \
--graph_api_args "$SRC_PATH/citeseer.content:$SRC_PATH/citeseer.cites" --graph_api_args "$SRC_PATH"

@ -9,4 +9,4 @@ python writer.py --mindrecord_script cora \
--mindrecord_partitions 1 \ --mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 18 \ --mindrecord_header_size_by_bit 18 \
--mindrecord_page_size_by_bit 20 \ --mindrecord_page_size_by_bit 20 \
--graph_api_args "$SRC_PATH/cora_content.csv:$SRC_PATH/cora_cites.csv" --graph_api_args "$SRC_PATH"

@ -527,10 +527,22 @@ void bindGraphData(py::module *m) {
THROW_IF_ERROR(g_out->Init()); THROW_IF_ERROR(g_out->Init());
return g_out; return g_out;
})) }))
.def("get_nodes", .def("get_all_nodes",
[](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) { [](gnn::Graph &g, gnn::NodeType node_type) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out)); THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
return out;
})
.def("get_all_edges",
[](gnn::Graph &g, gnn::EdgeType edge_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
return out;
})
.def("get_nodes_from_edges",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
return out; return out;
}) })
.def("get_all_neighbors", .def("get_all_neighbors",
@ -539,11 +551,30 @@ void bindGraphData(py::module *m) {
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
return out; return out;
}) })
.def("get_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
std::vector<gnn::NodeType> neighbor_types) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
return out;
})
.def("get_neg_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
gnn::NodeType neg_neighbor_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
return out;
})
.def("get_node_feature", .def("get_node_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) { [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out; TensorRow out;
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
return out; return out;
})
.def("graph_info", [](gnn::Graph &g) {
py::dict out;
THROW_IF_ERROR(g.GraphInfo(&out));
return out;
}); });
} }

File diff suppressed because it is too large Load Diff

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <map>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
@ -33,24 +34,13 @@ namespace mindspore {
namespace dataset { namespace dataset {
namespace gnn { namespace gnn {
struct NodeMetaInfo { struct MetaInfo {
NodeType type; std::vector<NodeType> node_type;
NodeIdType num; std::vector<EdgeType> edge_type;
std::vector<FeatureType> feature_type; std::map<NodeType, NodeIdType> node_num;
NodeMetaInfo() { std::map<EdgeType, EdgeIdType> edge_num;
type = 0; std::vector<FeatureType> node_feature_type;
num = 0; std::vector<FeatureType> edge_feature_type;
}
};
struct EdgeMetaInfo {
EdgeType type;
EdgeIdType num;
std::vector<FeatureType> feature_type;
EdgeMetaInfo() {
type = 0;
num = 0;
}
}; };
class Graph { class Graph {
@ -62,19 +52,23 @@ class Graph {
~Graph() = default; ~Graph() = default;
// Get the nodes from the graph. // Get all nodes from the graph.
// @param NodeType node_type - type of node // @param NodeType node_type - type of node
// @param NodeIdType node_num - Number of nodes to be acquired, if -1 means all nodes are acquired
// @param std::shared_ptr<Tensor> *out - Returned nodes id // @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return // @return Status - The error code return
Status GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out); Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out);
// Get the edges from the graph. // Get all edges from the graph.
// @param NodeType edge_type - type of edge // @param NodeType edge_type - type of edge
// @param NodeIdType edge_num - Number of edges to be acquired, if -1 means all edges are acquired
// @param std::shared_ptr<Tensor> *out - Returned edge ids // @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return // @return Status - The error code return
Status GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out); Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out);
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out);
// All neighbors of the acquisition node. // All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes
@ -86,10 +80,24 @@ class Graph {
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out); std::shared_ptr<Tensor> *out);
Status GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, // Get sampled neighbors.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out); const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
Status GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
// Get negative sampled neighbors.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out); NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q, Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q,
NodeIdType default_node, std::shared_ptr<Tensor> *out); NodeIdType default_node, std::shared_ptr<Tensor> *out);
@ -112,10 +120,12 @@ class Graph {
TensorRow *out); TensorRow *out);
// Get meta information of graph // Get meta information of graph
// @param std::vector<NodeMetaInfo> *node_info - Returned meta information of node // @param MetaInfo *meta_info - Returned meta information
// @param std::vector<NodeMetaInfo> *node_info - Returned meta information of edge
// @return Status - The error code return // @return Status - The error code return
Status GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info); Status GetMetaInfo(MetaInfo *meta_info);
// Return meta information to python layer
Status GraphInfo(py::dict *out);
Status Init(); Status Init();
@ -146,8 +156,24 @@ class Graph {
// @return Status - The error code return // @return Status - The error code return
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
// Find node object using node id
// @param NodeIdType id - Id of the node to look up
// @param std::shared_ptr<Node> *node - Returned node object
// @return Status - The error code return
Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node);
// Negative sampling
// @param std::vector<NodeIdType> &input_data - The data set to be sampled
// @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded
// @param int32_t samples_num - Number of samples to draw
// @param std::vector<NodeIdType> *out_samples - Sampling results returned
// @return Status - The error code return
Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data,
int32_t samples_num, std::vector<NodeIdType> *out_samples);
std::string dataset_file_; std::string dataset_file_;
int32_t num_workers_; // The number of worker threads int32_t num_workers_; // The number of worker threads
std::mt19937 rnd_;
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_; std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_; std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;

@ -20,12 +20,13 @@
#include <utility> #include <utility>
#include "dataset/engine/gnn/edge.h" #include "dataset/engine/gnn/edge.h"
#include "dataset/util/random.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace gnn { namespace gnn {
LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type) {} LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); }
Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = features_.find(feature_type); auto itr = features_.find(feature_type);
@ -38,21 +39,49 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature>
} }
} }
Status LocalNode::GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) { Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) {
std::vector<NodeIdType> neighbors; std::vector<NodeIdType> neighbors;
auto itr = neighbor_nodes_.find(neighbor_type); auto itr = neighbor_nodes_.find(neighbor_type);
if (itr != neighbor_nodes_.end()) { if (itr != neighbor_nodes_.end()) {
if (samples_num == -1) {
// Return all neighbors
neighbors.resize(itr->second.size() + 1); neighbors.resize(itr->second.size() + 1);
neighbors[0] = id_; neighbors[0] = id_;
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
[](const std::shared_ptr<Node> node) { return node->id(); }); [](const std::shared_ptr<Node> node) { return node->id(); });
} else { } else {
MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
neighbors.emplace_back(id_);
}
*out_neighbors = std::move(neighbors);
return Status::OK();
}
// Append the ids of up to `samples_num` randomly chosen entries of `neighbors` to `out`.
// @param const std::vector<std::shared_ptr<Node>> &neighbors - Candidate neighbor nodes
// @param int32_t samples_num - Upper bound on the number of ids appended by this call
// @param std::vector<NodeIdType> *out - Vector the sampled ids are appended to (existing content is kept)
// @return Status - The error code return
Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
                                      std::vector<NodeIdType> *out) {
  // Shuffle a permutation of the positions rather than the node pointers themselves.
  std::vector<NodeIdType> order(neighbors.size());
  std::iota(order.begin(), order.end(), 0);
  std::shuffle(order.begin(), order.end(), rnd_);

  // Never take more ids than there are candidates.
  const int32_t available = static_cast<int32_t>(neighbors.size());
  const int32_t take = (samples_num < available) ? samples_num : available;
  for (int32_t pos = 0; pos < take; ++pos) {
    out->push_back(neighbors[order[pos]]->id());
  }
  return Status::OK();
}
Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
std::vector<NodeIdType> *out_neighbors) {
std::vector<NodeIdType> neighbors;
neighbors.reserve(samples_num);
auto itr = neighbor_nodes_.find(neighbor_type);
if (itr != neighbor_nodes_.end()) {
while (neighbors.size() < samples_num) {
RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors));
} }
} else { } else {
neighbors.push_back(id_); MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; // If there are no neighbors, they are filled with kDefaultNodeId
for (int32_t i = 0; i < samples_num; ++i) {
neighbors.emplace_back(kDefaultNodeId);
}
} }
*out_neighbors = std::move(neighbors); *out_neighbors = std::move(neighbors);
return Status::OK(); return Status::OK();

@ -43,12 +43,19 @@ class LocalNode : public Node {
// @return Status - The error code return // @return Status - The error code return
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
// Get the neighbors of a node // Get the all neighbors of a node
// @param NodeType neighbor_type - type of neighbor // @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return // @return Status - The error code return
Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) override; Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) override;
// Get the sampled neighbors of a node
// @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
std::vector<NodeIdType> *out_neighbors) override;
// Add neighbor of node // Add neighbor of node
// @param std::shared_ptr<Node> node - // @param std::shared_ptr<Node> node -
@ -61,6 +68,10 @@ class LocalNode : public Node {
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
private: private:
Status GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
std::vector<NodeIdType> *out);
std::mt19937 rnd_;
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_; std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_; std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_;
}; };

@ -52,12 +52,19 @@ class Node {
// @return Status - The error code return // @return Status - The error code return
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
// Get the neighbors of a node // Get the all neighbors of a node
// @param NodeType neighbor_type - type of neighbor // @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return // @return Status - The error code return
virtual Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) = 0; virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) = 0;
// Get the sampled neighbors of a node
// @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
std::vector<NodeIdType> *out_neighbors) = 0;
// Add neighbor of node // Add neighbor of node
// @param std::shared_ptr<Node> node - // @param std::shared_ptr<Node> node -

@ -20,8 +20,9 @@ import numpy as np
from mindspore._c_dataengine import Graph from mindspore._c_dataengine import Graph
from mindspore._c_dataengine import Tensor from mindspore._c_dataengine import Tensor
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \ from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
check_gnn_get_node_feature check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature
class GraphData: class GraphData:
@ -60,7 +61,44 @@ class GraphData:
Raises: Raises:
TypeError: If `node_type` is not integer. TypeError: If `node_type` is not integer.
""" """
return self._graph.get_nodes(node_type, -1).as_array() return self._graph.get_all_nodes(node_type).as_array()
@check_gnn_get_all_edges
def get_all_edges(self, edge_type):
"""
Get all edges of the given type in the graph.
Args:
edge_type (int): Specify the type of edge.
Returns:
numpy.ndarray: array of edges.
Examples:
>>> import mindspore.dataset as ds
>>> data_graph = ds.GraphData('dataset_file', 2)
>>> edges = data_graph.get_all_edges(0)
Raises:
TypeError: If `edge_type` is not integer.
"""
return self._graph.get_all_edges(edge_type).as_array()
@check_gnn_get_nodes_from_edges
def get_nodes_from_edges(self, edge_list):
"""
Get the nodes that belong to the given edges.
Args:
edge_list (list or numpy.ndarray): The given list of edges.
Returns:
numpy.ndarray: array of nodes.
Examples:
>>> import mindspore.dataset as ds
>>> data_graph = ds.GraphData('dataset_file', 2)
>>> edges = data_graph.get_all_edges(0)
>>> nodes = data_graph.get_nodes_from_edges(edges)
Raises:
TypeError: If `edge_list` is not list or ndarray.
"""
return self._graph.get_nodes_from_edges(edge_list).as_array()
@check_gnn_get_all_neighbors @check_gnn_get_all_neighbors
def get_all_neighbors(self, node_list, neighbor_type): def get_all_neighbors(self, node_list, neighbor_type):
@ -86,6 +124,58 @@ class GraphData:
""" """
return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() return self._graph.get_all_neighbors(node_list, neighbor_type).as_array()
@check_gnn_get_sampled_neighbors
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
"""
Get sampled neighbor information. At most 6-hop sampling is supported.
Args:
node_list (list or numpy.ndarray): The given list of nodes.
neighbor_nums (list or numpy.ndarray): Number of neighbors sampled per hop.
neighbor_types (list or numpy.ndarray): Neighbor type sampled per hop.
Returns:
numpy.ndarray: array of nodes.
Examples:
>>> import mindspore.dataset as ds
>>> data_graph = ds.GraphData('dataset_file', 2)
>>> nodes = data_graph.get_all_nodes(0)
>>> neighbors = data_graph.get_sampled_neighbors(nodes, [2, 2], [0, 0])
Raises:
TypeError: If `node_list` is not list or ndarray.
TypeError: If `neighbor_nums` is not list or ndarray.
TypeError: If `neighbor_types` is not list or ndarray.
"""
return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array()
@check_gnn_get_neg_sampled_neighbors
def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
"""
Sample negative neighbors of type `neg_neighbor_type` for the nodes in `node_list`.
Args:
node_list (list or numpy.ndarray): The given list of nodes.
neg_neighbor_num (int): Number of neighbors sampled.
neg_neighbor_type (int): Specify the type of negative neighbor.
Returns:
numpy.ndarray: array of nodes.
Examples:
>>> import mindspore.dataset as ds
>>> data_graph = ds.GraphData('dataset_file', 2)
>>> nodes = data_graph.get_all_nodes(0)
>>> neg_neighbors = data_graph.get_neg_sampled_neighbors(nodes, 5, 0)
Raises:
TypeError: If `node_list` is not list or ndarray.
TypeError: If `neg_neighbor_num` is not integer.
TypeError: If `neg_neighbor_type` is not integer.
"""
return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array()
@check_gnn_get_node_feature @check_gnn_get_node_feature
def get_node_feature(self, node_list, feature_types): def get_node_feature(self, node_list, feature_types):
""" """
@ -111,3 +201,13 @@ class GraphData:
if isinstance(node_list, list): if isinstance(node_list, list):
node_list = np.array(node_list, dtype=np.int32) node_list = np.array(node_list, dtype=np.int32)
return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)] return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)]
def graph_info(self):
    """
    Return the meta information of the graph.

    The information covers the number of nodes, the type of nodes, the feature
    information of nodes, the number of edges, the type of edges, and the
    feature information of edges.

    Returns:
        Dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
        node_feature_type and edge_feature_type.
    """
    # Delegates entirely to the underlying graph object.
    meta = self._graph.graph_info()
    return meta

@ -1153,6 +1153,36 @@ def check_gnn_get_all_nodes(method):
return new_method return new_method
def check_gnn_get_all_edges(method):
    """Wrap the GNN `get_all_edges` function with a parameter checker."""

    @wraps(method)
    def checked_call(*args, **kwargs):
        params = make_param_dict(method, args, kwargs)

        # check edge_type; required argument, validated against int via check_type
        edge_type = params.get("edge_type")
        check_type(edge_type, 'edge_type', int)

        return method(*args, **kwargs)

    return checked_call
def check_gnn_get_nodes_from_edges(method):
    """Wrap the GNN `get_nodes_from_edges` function with a parameter checker."""

    @wraps(method)
    def checked_call(*args, **kwargs):
        params = make_param_dict(method, args, kwargs)

        # check edge_list; required argument
        edge_list = params.get("edge_list")
        check_gnn_list_or_ndarray(edge_list, 'edge_list')

        return method(*args, **kwargs)

    return checked_call
def check_gnn_get_all_neighbors(method): def check_gnn_get_all_neighbors(method):
"""A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function."""
@ -1171,6 +1201,61 @@ def check_gnn_get_all_neighbors(method):
return new_method return new_method
def check_gnn_get_sampled_neighbors(method):
    """
    Wrap the GNN `get_sampled_neighbors` function with a parameter checker.

    The wrapper validates `node_list`, `neighbor_nums` and `neighbor_types`
    before delegating to `method`.

    Raises (from the returned wrapper):
        ValueError: If `neighbor_nums` or `neighbor_types` has more than 6 members,
            or if the two have different lengths.
    """

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        # check node_list; required argument
        check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')

        # check neighbor_nums; required argument
        neighbor_nums = param_dict.get("neighbor_nums")
        check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
        if len(neighbor_nums) > 6:
            raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
                'neighbor_nums', len(neighbor_nums)))

        # check neighbor_types; required argument
        neighbor_types = param_dict.get("neighbor_types")
        check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
        # The limit must be tested on neighbor_types itself; testing neighbor_nums
        # again would make this branch unreachable.
        if len(neighbor_types) > 6:
            raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
                'neighbor_types', len(neighbor_types)))

        # one sample count and one neighbor type are required per hop
        if len(neighbor_nums) != len(neighbor_types):
            raise ValueError(
                "The number of members of neighbor_nums and neighbor_types is inconsistent")

        return method(*args, **kwargs)

    return new_method
def check_gnn_get_neg_sampled_neighbors(method):
    """Wrap the GNN `get_neg_sampled_neighbors` function with a parameter checker."""

    @wraps(method)
    def checked_call(*args, **kwargs):
        params = make_param_dict(method, args, kwargs)

        # check node_list; required argument
        node_list = params.get("node_list")
        check_gnn_list_or_ndarray(node_list, 'node_list')

        # check neg_neighbor_num; required argument
        neg_neighbor_num = params.get("neg_neighbor_num")
        check_type(neg_neighbor_num, 'neg_neighbor_num', int)

        # check neg_neighbor_type; required argument
        neg_neighbor_type = params.get("neg_neighbor_type")
        check_type(neg_neighbor_type, 'neg_neighbor_type', int)

        return method(*args, **kwargs)

    return checked_call
def check_aligned_list(param, param_name, membor_type): def check_aligned_list(param, param_name, membor_type):
"""Check whether the structure of each member of the list is the same.""" """Check whether the structure of each member of the list is the same."""

@ -13,8 +13,10 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <algorithm>
#include <string> #include <string>
#include <memory> #include <memory>
#include <unordered_set>
#include "common/common.h" #include "common/common.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
@ -45,7 +47,7 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) {
&default_feature_map) &default_feature_map)
.IsOk()); .IsOk());
EXPECT_EQ(n_id_map.size(), 20); EXPECT_EQ(n_id_map.size(), 20);
EXPECT_EQ(e_id_map.size(), 20); EXPECT_EQ(e_id_map.size(), 40);
EXPECT_EQ(n_type_map[2].size(), 10); EXPECT_EQ(n_type_map[2].size(), 10);
EXPECT_EQ(n_type_map[1].size(), 10); EXPECT_EQ(n_type_map[1].size(), 10);
} }
@ -56,14 +58,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
Status s = graph.Init(); Status s = graph.Init();
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
std::vector<NodeMetaInfo> node_info; MetaInfo meta_info;
std::vector<EdgeMetaInfo> edge_info; s = graph.GetMetaInfo(&meta_info);
s = graph.GetMetaInfo(&node_info, &edge_info);
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(node_info.size() == 2); EXPECT_TRUE(meta_info.node_type.size() == 2);
std::shared_ptr<Tensor> nodes; std::shared_ptr<Tensor> nodes;
s = graph.GetNodes(node_info[1].type, -1, &nodes); s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
std::vector<NodeIdType> node_list; std::vector<NodeIdType> node_list;
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
@ -73,13 +74,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
} }
} }
std::shared_ptr<Tensor> neighbors; std::shared_ptr<Tensor> neighbors;
s = graph.GetAllNeighbors(node_list, node_info[0].type, &neighbors); s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], &neighbors);
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>"); EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>");
TensorRow features; TensorRow features;
s = graph.GetNodeFeature(nodes, node_info[1].feature_type, &features); s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features);
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(features.size() == 3); EXPECT_TRUE(features.size() == 4);
EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>"); EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>");
EXPECT_TRUE(features[0]->ToString() == EXPECT_TRUE(features[0]->ToString() ==
"Tensor (shape: <10,5>, Type: int32)\n" "Tensor (shape: <10,5>, Type: int32)\n"
@ -91,3 +92,106 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
EXPECT_TRUE(features[2]->shape().ToString() == "<10>"); EXPECT_TRUE(features[2]->shape().ToString() == "<10>");
EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]"); EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]");
} }
// Exercises Graph::GetSampledNeighbors for 1-, 2- and 3-hop sampling and checks
// the error messages returned for an empty node list, mismatched hop arguments
// and an unknown node id.
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata";
Graph graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
MetaInfo meta_info;
s = graph.GetMetaInfo(&meta_info);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(meta_info.node_type.size() == 2);
// Fetch every edge of the first edge type and copy the ids into a vector.
std::shared_ptr<Tensor> edges;
s = graph.GetAllEdges(meta_info.edge_type[0], &edges);
EXPECT_TRUE(s.IsOk());
std::vector<EdgeIdType> edge_list;
edge_list.resize(edges->Size());
std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(),
[](const EdgeIdType edge) { return edge; });
std::shared_ptr<Tensor> nodes;
s = graph.GetNodesFromEdges(edge_list, &nodes);
EXPECT_TRUE(s.IsOk());
// Walk the returned node ids and keep only the 1st, 3rd, 5th, ... entries
// (presumably the first endpoint of each edge -- TODO confirm), stopping once
// 5 distinct ids have been collected.
std::unordered_set<NodeIdType> node_set;
std::vector<NodeIdType> node_list;
int index = 0;
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
index++;
if (index % 2 == 0) {
continue;
}
node_set.emplace(*itr);
if (node_set.size() >= 5) {
break;
}
}
node_list.resize(node_set.size());
std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; });
// 1-hop, 10 samples: expected shape <5,11>
// (presumably 1 seed column + 10 sampled neighbors -- TODO confirm layout).
std::shared_ptr<Tensor> neighbors;
s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
// 2-hop {2, 3}: expected shape <5,9> (consistent with 1 + 2 + 2*3 columns).
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>");
// 3-hop {2, 3, 4}: expected shape <5,33> (consistent with 1 + 2 + 2*3 + 2*3*4 columns).
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3, 4},
{meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>");
// Error path: empty node list.
neighbors.reset();
s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors);
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
// Error path: three sample counts but only two neighbor types.
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
std::string::npos);
// Error path: node id 301 is expected to be rejected as invalid for this test graph.
neighbors.reset();
s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos);
}
// Exercises Graph::GetNegSampledNeighbors on up to 10 nodes of the first node
// type and checks the error messages for an empty node list and an unknown
// node type.
TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata";
Graph graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
MetaInfo meta_info;
s = graph.GetMetaInfo(&meta_info);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(meta_info.node_type.size() == 2);
std::shared_ptr<Tensor> nodes;
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
EXPECT_TRUE(s.IsOk());
// Copy at most the first 10 node ids into node_list.
std::vector<NodeIdType> node_list;
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
node_list.push_back(*itr);
if (node_list.size() >= 10) {
break;
}
}
// 3 negative samples per node: expected shape <10,4>
// (presumably 1 seed column + 3 negative neighbors -- TODO confirm layout).
std::shared_ptr<Tensor> neg_neighbors;
s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neg_neighbors->shape().ToString() == "<10,4>");
// Error path: empty node list.
neg_neighbors.reset();
s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors);
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
// Error path: node type 3 is expected to be rejected as invalid for this test graph.
neg_neighbors.reset();
s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos);
}

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import random
import pytest import pytest
import numpy as np import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
@ -77,8 +78,110 @@ def test_graphdata_getnodefeature_input_check():
g.get_node_feature(input_list, [1, "a"]) g.get_node_feature(input_list, [1, "a"])
def test_graphdata_getsampledneighbors():
    """Sample 2-hop neighbors for the unique first-column nodes of the first 21 edge rows."""
    graph = ds.GraphData(DATASET_FILE, 1)
    all_edges = graph.get_all_edges(0)
    edge_nodes = graph.get_nodes_from_edges(all_edges)
    assert len(edge_nodes) == 40

    # First column of the first 21 rows, de-duplicated, is used as the seed set.
    seed_nodes = np.unique(edge_nodes[0:21, 0])
    sampled = graph.get_sampled_neighbors(seed_nodes, [2, 3], [2, 1])
    assert sampled.shape == (10, 9)
def test_graphdata_getnegsampledneighbors():
    """Negative-sample 5 neighbors of type 2 for every node of type 1."""
    graph = ds.GraphData(DATASET_FILE, 2)
    type1_nodes = graph.get_all_nodes(1)
    assert len(type1_nodes) == 10

    neg_sampled = graph.get_neg_sampled_neighbors(type1_nodes, 5, 2)
    assert neg_sampled.shape == (10, 6)
def test_graphdata_graphinfo():
    """Check every documented key of the graph meta information for the test dataset."""
    graph = ds.GraphData(DATASET_FILE, 2)
    info = graph.graph_info()

    # Checked in a fixed order, one key at a time.
    expected = (
        ('node_type', [1, 2]),
        ('edge_type', [0]),
        ('node_num', {1: 10, 2: 10}),
        ('edge_num', {0: 40}),
        ('node_feature_type', [1, 2, 3, 4]),
        ('edge_feature_type', []),
    )
    for key, want in expected:
        assert info[key] == want
class RandomBatchedSampler(ds.Sampler):
    """Yield a shuffled 1-based index sequence, without replacement, in fixed-size batches."""

    def __init__(self, index_range, num_edges_per_sample):
        super().__init__()
        self.index_range = index_range
        self.num_edges_per_sample = num_edges_per_sample

    def __iter__(self):
        # Indices run from 1 to index_range inclusive.
        shuffled = list(range(1, self.index_range + 1))
        # Reset random seed here if necessary
        # random.seed(0)
        random.shuffle(shuffled)
        for begin in range(0, self.index_range, self.num_edges_per_sample):
            end = begin + self.num_edges_per_sample
            # Drop remainder: skip a trailing batch that would be shorter than requested.
            if end > self.index_range:
                continue
            yield shuffled[begin:end]
class GNNGraphDataset:
    """Random-access source that turns a batch of edge ids into sampled neighbors and features."""

    def __init__(self, g, batch_num):
        self.g = g
        self.batch_num = batch_num

    def __len__(self):
        # Total sample size of GNN dataset:
        # number of edges of type 0 divided by the number of edges per sample.
        edge_total = self.g.graph_info()['edge_num'][0]
        return edge_total // self.batch_num

    def __getitem__(self, index):
        # index is the batch of edge ids yielded by the sampler; it must support .astype.
        edge_ids = index.astype(np.int32)
        # Keep only the first column of the nodes returned for those edges.
        nodes = self.g.get_nodes_from_edges(edge_ids)[:, 0]

        neg_nodes = self.g.get_neg_sampled_neighbors(
            node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)

        nodes_neighbors = self.g.get_sampled_neighbors(
            node_list=nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1])

        # Column 0 of neg_nodes is dropped; the remaining ids are flattened into seeds.
        neg_seeds = neg_nodes[:, 1:].reshape(-1)
        neg_nodes_neighbors = self.g.get_sampled_neighbors(
            node_list=neg_seeds, neighbor_nums=[2, 2], neighbor_types=[2, 2])

        nodes_neighbors_features = self.g.get_node_feature(
            node_list=nodes_neighbors, feature_types=[2, 3])
        neg_neighbors_features = self.g.get_node_feature(
            node_list=neg_nodes_neighbors, feature_types=[2, 3])

        # Feature entry 0 is returned for the positive side and entry 1 for the negative side.
        return (nodes_neighbors, neg_nodes_neighbors,
                nodes_neighbors_features[0], neg_neighbors_features[1])
def test_graphdata_generatordataset():
    """Feed GNNGraphDataset through GeneratorDataset with a batched sampler and check row shapes."""
    graph = ds.GraphData(DATASET_FILE)
    batch_num = 2
    edge_num = graph.graph_info()['edge_num'][0]
    out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]

    source = GNNGraphDataset(graph, batch_num)
    sampler = RandomBatchedSampler(edge_num, batch_num)
    dataset = ds.GeneratorDataset(source=source, column_names=out_column_names,
                                  sampler=sampler, num_parallel_workers=4)
    dataset = dataset.repeat(2)

    row_count = 0
    for row in dataset.create_dict_iterator():
        assert row['neighbors'].shape == (2, 7)
        assert row['neg_neighbors'].shape == (6, 7)
        assert row['neighbors_features'].shape == (2, 7)
        assert row['neg_neighbors_features'].shape == (6, 7)
        row_count += 1
    # The test expects 40 rows in total after repeat(2).
    assert row_count == 40
if __name__ == '__main__': if __name__ == '__main__':
test_graphdata_getfullneighbor() test_graphdata_getfullneighbor()
logger.info('test_graphdata_getfullneighbor Ended.\n') logger.info('test_graphdata_getfullneighbor Ended.\n')
test_graphdata_getnodefeature_input_check() test_graphdata_getnodefeature_input_check()
logger.info('test_graphdata_getnodefeature_input_check Ended.\n') logger.info('test_graphdata_getnodefeature_input_check Ended.\n')
test_graphdata_getsampledneighbors()
logger.info('test_graphdata_getsampledneighbors Ended.\n')
test_graphdata_getnegsampledneighbors()
logger.info('test_graphdata_getnegsampledneighbors Ended.\n')
test_graphdata_graphinfo()
logger.info('test_graphdata_graphinfo Ended.\n')
test_graphdata_generatordataset()
logger.info('test_graphdata_generatordataset Ended.\n')

Loading…
Cancel
Save