!1249 Add GNN dataset processing API

Merge pull request !1249 from heleiwang/hlw_gnn_data
pull/1249/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 3363d4e834

@ -0,0 +1,73 @@
# Guideline to Efficiently Generating MindRecord
<!-- TOC -->
- [What does the example do](#what-does-the-example-do)
- [Example test for Cora](#example-test-for-cora)
- [How to use the example for other dataset](#how-to-use-the-example-for-other-dataset)
- [Create work space](#create-work-space)
- [Implement data generator](#implement-data-generator)
- [Run data generator](#run-data-generator)
<!-- /TOC -->
## What does the example do
This example provides an efficient way to generate MindRecord. Users only need to define the parallel granularity of training data reading and the data reading function of a single task. That is, they can efficiently convert the user's training data into MindRecord.
1. write_cora.sh: entry script, users need to modify parameters according to their own training data.
2. writer.py: main script, called by write_cora.sh, it mainly reads user training data in parallel and generates MindRecord.
3. cora/mr_api.py: uers define their own parallel granularity of training data reading and single task reading function through the cora.
## Example test for Cora
1. Download and prepare the Cora dataset as required.
> [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora)
2. Edit write_cora.sh and modify the parameters
```
--mindrecord_file: output MindRecord file.
--mindrecord_partitions: the partitions for MindRecord.
```
3. Run the bash script
```bash
bash write_cora.sh
```
## How to use the example for other dataset
### Create work space
Assume the dataset name is 'xyz'
* Create work space from cora
```shell
cd ${your_mindspore_home}/example/graph_to_mindrecord
cp -r cora xyz
```
### Implement data generator
Edit dictionary data generator.
* Edit file
```shell
cd ${your_mindspore_home}/example/graph_to_mindrecord
vi xyz/mr_api.py
```
Two API, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented.
- 'mindrecord_task_number()' returns number of tasks. Return 1 if data row is generated serially. Return N if generator can be split into N parallel-run tasks.
- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' is 0..N-1, if N is return value of mindrecord_task_number()
### Run data generator
* run python script
```shell
cd ${your_mindspore_home}/example/graph_to_mindrecord
python writer.py --mindrecord_script xyz [...]
```
> You can put this command in script **write_xyz.sh** for easy execution

@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord GNN writer.
"""
import csv
import os
import numpy as np
import scipy.sparse as sp
# parse args from command line parameter 'graph_api_args'
# args delimiter is ':'
args = os.environ['graph_api_args'].split(':')
CITESEER_CONTENT_FILE = args[0]
CITESEER_CITES_FILE = args[1]
CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord"
CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord"
node_id_map = {}
# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
edge_profile = (0, [], [])
def _normalize_citeseer_features(features):
features = np.array(features)
row_sum = np.array(features.sum(1))
r_inv = np.power(row_sum * 1.0, -1).flatten()
r_inv[np.isinf(r_inv)] = 0.
r_mat_inv = sp.diags(r_inv)
features = r_mat_inv.dot(features)
return features
def yield_nodes(task_id=0):
"""
Generate node data
Yields:
data (dict): data row which is dict.
"""
print("Node task is {}".format(task_id))
label_types = {}
label_size = 0
node_num = 0
with open(CITESEER_CONTENT_FILE) as content_file:
content_reader = csv.reader(content_file, delimiter='\t')
line_count = 0
for row in content_reader:
if not row[-1] in label_types:
label_types[row[-1]] = label_size
label_size += 1
if not row[0] in node_id_map:
node_id_map[row[0]] = node_num
node_num += 1
raw_features = [[int(x) for x in row[1:-1]]]
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features),
'feature_2': [label_types[row[-1]]]}
yield node
line_count += 1
print('Processed {} lines for nodes.'.format(line_count))
# print('label types {}.'.format(label_types))
with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f:
for k in label_types:
print(k + ',' + str(label_types[k]), file=f)
def yield_edges(task_id=0):
"""
Generate edge data
Yields:
data (dict): data row which is dict.
"""
print("Edge task is {}".format(task_id))
# print(map_string_int)
with open(CITESEER_CITES_FILE) as cites_file:
cites_reader = csv.reader(cites_file, delimiter='\t')
line_count = 0
for row in cites_reader:
if not row[0] in node_id_map:
print('Source node {} does not exist.'.format(row[0]))
continue
if not row[1] in node_id_map:
print('Destination node {} does not exist.'.format(row[1]))
continue
line_count += 1
edge = {'id': line_count,
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
yield edge
with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f:
for k in node_id_map:
print(k + ',' + str(node_id_map[k]), file=f)
print('Processed {} lines for edges.'.format(line_count))

@ -0,0 +1,113 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord GNN writer.
"""
import csv
import os
import numpy as np
import scipy.sparse as sp
# parse args from command line parameter 'graph_api_args'
# args delimiter is ':'
args = os.environ['graph_api_args'].split(':')
CORA_CONTENT_FILE = args[0]
CORA_CITES_FILE = args[1]
CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord"
CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord"
node_id_map = {}
# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
edge_profile = (0, [], [])
def _normalize_cora_features(features):
features = np.array(features)
row_sum = np.array(features.sum(1))
r_inv = np.power(row_sum * 1.0, -1).flatten()
r_inv[np.isinf(r_inv)] = 0.
r_mat_inv = sp.diags(r_inv)
features = r_mat_inv.dot(features)
return features
def yield_nodes(task_id=0):
"""
Generate node data
Yields:
data (dict): data row which is dict.
"""
print("Node task is {}".format(task_id))
label_types = {}
label_size = 0
node_num = 0
with open(CORA_CONTENT_FILE) as content_file:
content_reader = csv.reader(content_file, delimiter=',')
line_count = 0
for row in content_reader:
if line_count == 0:
line_count += 1
continue
if not row[0] in node_id_map:
node_id_map[row[0]] = node_num
node_num += 1
if not row[-1] in label_types:
label_types[row[-1]] = label_size
label_size += 1
raw_features = [[int(x) for x in row[1:-1]]]
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features),
'feature_2': [label_types[row[-1]]]}
yield node
line_count += 1
print('Processed {} lines for nodes.'.format(line_count))
print('label types {}.'.format(label_types))
with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f:
for k in label_types:
print(k + ',' + str(label_types[k]), file=f)
def yield_edges(task_id=0):
"""
Generate edge data
Yields:
data (dict): data row which is dict.
"""
print("Edge task is {}".format(task_id))
with open(CORA_CITES_FILE) as cites_file:
cites_reader = csv.reader(cites_file, delimiter=',')
line_count = 0
for row in cites_reader:
if line_count == 0:
line_count += 1
continue
if not row[0] in node_id_map:
print('Source node {} does not exist.'.format(row[0]))
continue
if not row[1] in node_id_map:
print('Destination node {} does not exist.'.format(row[1]))
continue
edge = {'id': line_count,
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
yield edge
line_count += 1
print('Processed {} lines for edges.'.format(line_count))
with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f:
for k in node_id_map:
print(k + ',' + str(node_id_map[k]), file=f)

@ -0,0 +1,2 @@
#!/bin/bash
python reader.py --path "/tmp/citeseer/mindrecord/citeseer_mr"

@ -0,0 +1,2 @@
#!/bin/bash
python reader.py --path "/tmp/cora/mindrecord/cora_mr"

@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
######################## write mindrecord example ########################
Write mindrecord by data dictionary:
python writer.py --mindrecord_script /YourScriptPath ...
"""
import argparse
import mindspore.dataset as ds
parser = argparse.ArgumentParser(description='Mind record reader')
parser.add_argument('--path', type=str, default="/tmp/cora/mindrecord/cora_mr",
help='data file')
args = parser.parse_args()
data_set = ds.MindDataset(args.path)
num_iter = 0
for item in data_set.create_dict_iterator():
print(item)
num_iter += 1
print("Total items # is {}".format(num_iter))

@ -0,0 +1,9 @@
#!/bin/bash
rm /tmp/citeseer/mindrecord/*
python writer.py --mindrecord_script citeseer \
--mindrecord_file "/tmp/citeseer/mindrecord/citeseer_mr" \
--mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 18 \
--mindrecord_page_size_by_bit 20 \
--graph_api_args "/tmp/citeseer/dataset/citeseer.content:/tmp/citeseer/dataset/citeseer.cites"

@ -0,0 +1,9 @@
#!/bin/bash
rm /tmp/cora/mindrecord/*
python writer.py --mindrecord_script cora \
--mindrecord_file "/tmp/cora/mindrecord/cora_mr" \
--mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 18 \
--mindrecord_page_size_by_bit 20 \
--graph_api_args "/tmp/cora/dataset/cora_content.csv:/tmp/cora/dataset/cora_cites.csv"

@ -0,0 +1,186 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
######################## write mindrecord example ########################
Write mindrecord by data dictionary:
python writer.py --mindrecord_script /YourScriptPath ...
"""
import argparse
import os
import time
from importlib import import_module
from multiprocessing import Pool
from mindspore.mindrecord import FileWriter
from mindspore.mindrecord import GraphMapSchema
def exec_task(task_id, parallel_writer=True):
"""
Execute task with specified task id
"""
print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
imagenet_iter = mindrecord_dict_data(task_id)
batch_size = 512
transform_count = 0
while True:
data_list = []
try:
for _ in range(batch_size):
data = imagenet_iter.__next__()
if 'dst_id' in data:
data = graph_map_schema.transform_edge(data)
else:
data = graph_map_schema.transform_node(data)
data_list.append(data)
transform_count += 1
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
print("transformed {} record...".format(transform_count))
except StopIteration:
if data_list:
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
print("transformed {} record...".format(transform_count))
break
def read_args():
"""
read args
"""
parser = argparse.ArgumentParser(description='Mind record writer')
parser.add_argument('--mindrecord_script', type=str, default="template",
help='path where script is saved')
parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord/xyz",
help='written file name prefix')
parser.add_argument('--mindrecord_partitions', type=int, default=1,
help='number of written files')
parser.add_argument('--mindrecord_header_size_by_bit', type=int, default=24,
help='mindrecord file header size')
parser.add_argument('--mindrecord_page_size_by_bit', type=int, default=25,
help='mindrecord file page size')
parser.add_argument('--mindrecord_workers', type=int, default=8,
help='number of parallel workers')
parser.add_argument('--num_node_tasks', type=int, default=1,
help='number of node tasks')
parser.add_argument('--num_edge_tasks', type=int, default=1,
help='number of node tasks')
parser.add_argument('--graph_api_args', type=str, default="/tmp/nodes.csv:/tmp/edges.csv",
help='nodes and edges data file, csv format with header.')
ret_args = parser.parse_args()
return ret_args
def init_writer(mr_schema):
"""
init writer
"""
print("Init writer ...")
mr_writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)
# set the header size
if args.mindrecord_header_size_by_bit != 24:
header_size = 1 << args.mindrecord_header_size_by_bit
mr_writer.set_header_size(header_size)
# set the page size
if args.mindrecord_page_size_by_bit != 25:
page_size = 1 << args.mindrecord_page_size_by_bit
mr_writer.set_page_size(page_size)
# create the schema
mr_writer.add_schema(mr_schema, "mindrecord_graph_schema")
# open file and set header
mr_writer.open_and_set_header()
return mr_writer
def run_parallel_workers(num_tasks):
"""
run parallel workers
"""
# set number of workers
num_workers = args.mindrecord_workers
task_list = list(range(num_tasks))
if num_workers > num_tasks:
num_workers = num_tasks
if os.name == 'nt':
for window_task_id in task_list:
exec_task(window_task_id, False)
elif num_tasks > 1:
with Pool(num_workers) as p:
p.map(exec_task, task_list)
else:
exec_task(0, False)
if __name__ == "__main__":
args = read_args()
print(args)
start_time = time.time()
# pass mr_api arguments
os.environ['graph_api_args'] = args.graph_api_args
# import mr_api
try:
mr_api = import_module(args.mindrecord_script + '.mr_api')
except ModuleNotFoundError:
raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))
# init graph schema
graph_map_schema = GraphMapSchema()
num_features, feature_data_types, feature_shapes = mr_api.node_profile
graph_map_schema.set_node_feature_profile(num_features, feature_data_types, feature_shapes)
num_features, feature_data_types, feature_shapes = mr_api.edge_profile
graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)
graph_schema = graph_map_schema.get_schema()
# init writer
writer = init_writer(graph_schema)
# write nodes data
mindrecord_dict_data = mr_api.yield_nodes
run_parallel_workers(args.num_node_tasks)
# write edges data
mindrecord_dict_data = mr_api.yield_edges
run_parallel_workers(args.num_edge_tasks)
# writer wrap up
ret = writer.commit()
end_time = time.time()
print("--------------------------------------------")
print("END. Total time: {}".format(end_time - start_time))
print("--------------------------------------------")

@ -66,6 +66,7 @@ set(submodules
$<TARGET_OBJECTS:APItoPython>
$<TARGET_OBJECTS:engine-datasetops-source>
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
$<TARGET_OBJECTS:engine-gnn>
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine>

@ -61,6 +61,7 @@
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
@ -513,6 +514,33 @@ void bindVocabObjects(py::module *m) {
});
}
void bindGraphData(py::module *m) {
(void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
.def(py::init([](std::string dataset_file, int32_t num_workers) {
std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
THROW_IF_ERROR(g_out->Init());
return g_out;
}))
.def("get_nodes",
[](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out));
return out;
})
.def("get_all_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
return out;
})
.def("get_node_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out;
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
return out;
});
}
// This is where we externalize the C logic as python modules
PYBIND11_MODULE(_c_dataengine, m) {
m.doc() = "pybind11 for _c_dataengine";
@ -578,6 +606,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
bindDatasetOps(&m);
bindInfoObjects(&m);
bindVocabObjects(&m);
bindGraphData(&m);
}
} // namespace dataset
} // namespace mindspore

@ -1,5 +1,6 @@
add_subdirectory(datasetops)
add_subdirectory(opt)
add_subdirectory(gnn)
if (ENABLE_TDTQUE)
add_subdirectory(tdt)
endif ()
@ -15,7 +16,7 @@ add_library(engine OBJECT
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn)
else()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn)
endif ()

@ -112,10 +112,10 @@ Status MindRecordOp::Init() {
data_schema_ = std::make_unique<DataSchema>();
std::vector<std::string> col_names = shard_reader_->get_shard_column()->GetColumnName();
std::vector<std::string> col_names = shard_reader_->GetShardColumn()->GetColumnName();
CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found");
std::vector<mindrecord::ColumnDataType> col_data_types = shard_reader_->get_shard_column()->GeColumnDataType();
std::vector<std::vector<int64_t>> col_shapes = shard_reader_->get_shard_column()->GetColumnShape();
std::vector<mindrecord::ColumnDataType> col_data_types = shard_reader_->GetShardColumn()->GeColumnDataType();
std::vector<std::vector<int64_t>> col_shapes = shard_reader_->GetShardColumn()->GetColumnShape();
bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything
std::map<std::string, int32_t> colname_to_ind;
@ -296,8 +296,7 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
std::vector<int64_t> column_shape;
// Get column data
auto has_column = shard_reader_->get_shard_column()->GetColumnValueByName(
auto has_column = shard_reader_->GetShardColumn()->GetColumnValueByName(
column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size,
&column_shape);
if (has_column == MSRStatus::FAILED) {

@ -0,0 +1,7 @@
add_library(engine-gnn OBJECT
graph.cc
graph_loader.cc
local_node.cc
local_edge.cc
feature.cc
)

@ -0,0 +1,86 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_GNN_EDGE_H_
#define DATASET_ENGINE_GNN_EDGE_H_
#include <memory>
#include <unordered_map>
#include <utility>
#include "dataset/util/status.h"
#include "dataset/engine/gnn/feature.h"
#include "dataset/engine/gnn/node.h"
namespace mindspore {
namespace dataset {
namespace gnn {
using EdgeType = int8_t;
using EdgeIdType = int32_t;
class Edge {
public:
// Constructor
// @param EdgeIdType id - edge id
// @param EdgeType type - edge type
// @param std::shared_ptr<Node> src_node - source node
// @param std::shared_ptr<Node> dst_node - destination node
Edge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
: id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {}
virtual ~Edge() = default;
// @return NodeIdType - Returned edge id
EdgeIdType id() const { return id_; }
// @return NodeIdType - Returned edge type
EdgeType type() const { return type_; }
// Get the feature of a edge
// @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
// Get nodes on the edge
// @param std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> *out_node - Source and destination nodes returned
Status GetNode(std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> *out_node) {
*out_node = std::make_pair(src_node_, dst_node_);
return Status::OK();
}
// Set node to edge
// @param const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> &in_node -
Status SetNode(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> &in_node) {
src_node_ = in_node.first;
dst_node_ = in_node.second;
return Status::OK();
}
// Update feature of edge
// @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;
protected:
EdgeIdType id_;
EdgeType type_;
std::shared_ptr<Node> src_node_;
std::shared_ptr<Node> dst_node_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_GNN_EDGE_H_

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/gnn/feature.h"
namespace mindspore {
namespace dataset {
namespace gnn {
Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value) : type_name_(type_name), value_(value) {}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_GNN_FEATURE_H_
#define DATASET_ENGINE_GNN_FEATURE_H_
#include <memory>
#include "dataset/core/tensor.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace gnn {
using FeatureType = int16_t;
class Feature {
public:
// Constructor
// @param FeatureType type_name - feature type
// @param std::shared_ptr<Tensor> value - feature value
Feature(FeatureType type_name, std::shared_ptr<Tensor> value);
// Get feature value
// @return std::shared_ptr<Tensor> *out_value - feature value
const std::shared_ptr<Tensor> Value() const { return value_; }
// @return NodeIdType - Returned feature type
FeatureType type() const { return type_name_; }
private:
FeatureType type_name_;
std::shared_ptr<Tensor> value_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_GNN_FEATURE_H_

File diff suppressed because it is too large Load Diff

@ -0,0 +1,166 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_GNN_GRAPH_H_
#define DATASET_ENGINE_GNN_GRAPH_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/engine/gnn/graph_loader.h"
#include "dataset/engine/gnn/feature.h"
#include "dataset/engine/gnn/node.h"
#include "dataset/engine/gnn/edge.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace gnn {
struct NodeMetaInfo {
NodeType type;
NodeIdType num;
std::vector<FeatureType> feature_type;
NodeMetaInfo() {
type = 0;
num = 0;
}
};
struct EdgeMetaInfo {
EdgeType type;
EdgeIdType num;
std::vector<FeatureType> feature_type;
EdgeMetaInfo() {
type = 0;
num = 0;
}
};
class Graph {
public:
// Constructor
// @param std::string dataset_file -
// @param int32_t num_workers - number of parallel threads
Graph(std::string dataset_file, int32_t num_workers);
~Graph() = default;
// Get the nodes from the graph.
// @param NodeType node_type - type of node
// @param NodeIdType node_num - Number of nodes to be acquired, if -1 means all nodes are acquired
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
Status GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out);
// Get the edges from the graph.
// @param NodeType edge_type - type of edge
// @param NodeIdType edge_num - Number of edges to be acquired, if -1 means all edges are acquired
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
Status GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out);
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status - The error code return
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out);
Status GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
Status GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q,
NodeIdType default_node, std::shared_ptr<Tensor> *out);
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out);
// Get the feature of a edge
// @param std::shared_ptr<Tensor> edget - List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edget, const std::vector<FeatureType> &feature_types,
TensorRow *out);
// Get meta information of graph
// @param std::vector<NodeMetaInfo> *node_info - Returned meta information of node
// @param std::vector<NodeMetaInfo> *node_info - Returned meta information of edge
// @return Status - The error code return
Status GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info);
Status Init();
private:
// Load graph data from mindrecord file
// @return Status - The error code return
Status LoadNodeAndEdge();
// Create Tensor By Vector
// @param std::vector<std::vector<T>> &data -
// @param DataType type -
// @param std::shared_ptr<Tensor> *out -
// @return Status - The error code return
template <typename T>
Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out);
// Complete vector
// @param std::vector<std::vector<T>> *data - To be completed vector
// @param size_t max_size - The size of the completed vector
// @param T default_value - Filled default
// @return Status - The error code return
template <typename T>
Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value);
// Get the default feature of a node
// @param FeatureType feature_type -
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
std::string dataset_file_;
int32_t num_workers_; // The number of worker threads
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
std::unordered_map<EdgeType, std::vector<EdgeIdType>> edge_type_map_;
std::unordered_map<EdgeIdType, std::shared_ptr<Edge>> edge_id_map_;
std::unordered_map<NodeType, std::unordered_set<FeatureType>> node_feature_map_;
std::unordered_map<NodeType, std::unordered_set<FeatureType>> edge_feature_map_;
std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_feature_map_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_GNN_GRAPH_H_

@ -0,0 +1,248 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <future>
#include <tuple>
#include <utility>
#include "dataset/engine/gnn/graph_loader.h"
#include "mindspore/ccsrc/mindrecord/include/shard_error.h"
#include "dataset/engine/gnn/local_edge.h"
#include "dataset/engine/gnn/local_node.h"
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
namespace mindspore {
namespace dataset {
namespace gnn {
using mindrecord::MSRStatus;
GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
: mr_path_(mr_filepath),
num_workers_(num_workers),
row_id_(0),
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map,
EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) {
for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
while (dq.empty() == false) {
std::shared_ptr<Node> node_ptr = dq.front();
n_id_map->insert({node_ptr->id(), node_ptr});
(*n_type_map)[node_ptr->type()].push_back(node_ptr->id());
dq.pop_front();
}
}
for (std::deque<std::shared_ptr<Edge>> &dq : e_deques_) {
while (dq.empty() == false) {
std::shared_ptr<Edge> edge_ptr = dq.front();
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p;
RETURN_IF_NOT_OK(edge_ptr->GetNode(&p));
auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id());
CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first));
CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first));
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
(*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id());
dq.pop_front();
}
}
for (auto &itr : *n_type_map) itr.second.shrink_to_fit();
for (auto &itr : *e_type_map) itr.second.shrink_to_fit();
MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map);
return Status::OK();
}
Status GraphLoader::InitAndLoad() {
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n");
CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n");
n_deques_.resize(num_workers_);
e_deques_.resize(num_workers_);
n_feature_maps_.resize(num_workers_);
e_feature_maps_.resize(num_workers_);
default_feature_maps_.resize(num_workers_);
std::vector<std::future<Status>> r_codes(num_workers_);
shard_reader_ = std::make_unique<ShardReader>();
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
"Fail to open" + mr_path_);
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"];
for (const std::string &key : keys_) {
if (schema.find(key) == schema.end()) {
RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
}
}
// launching worker threads
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id);
}
// wait for threads to finish and check its return code
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
RETURN_IF_NOT_OK(r_codes[wkr_id].get());
}
return Status::OK();
}
Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
std::shared_ptr<Node> *node, NodeFeatureMap *feature_map,
DefaultFeatureMap *default_feature) {
NodeIdType node_id = col_jsn["first_id"];
NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
(*node) = std::make_shared<LocalNode>(node_id, node_type);
std::vector<int32_t> indices;
RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices));
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
(*feature_map)[node_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
}
return Status::OK();
}
Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
std::shared_ptr<Edge> *edge, EdgeFeatureMap *feature_map,
DefaultFeatureMap *default_feature) {
EdgeIdType edge_id = col_jsn["first_id"];
EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]);
NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"];
std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1);
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
std::vector<int32_t> indices;
RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices));
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
(*feature_map)[edge_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
}
return Status::OK();
}
Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) {
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible,
std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data));
return Status::OK();
}
Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
const mindrecord::json &col_jsn, std::vector<int32_t> *indices) {
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
for (int i = 0; i < n_bytes; i += col_type_size) {
int32_t feature_ind = -1;
if (col_type == mindrecord::ColumnInt32) {
feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
} else if (col_type == mindrecord::ColumnInt64) {
feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
} else {
RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
}
if (feature_ind >= 0) indices->push_back(feature_ind);
}
return Status::OK();
}
Status GraphLoader::WorkerEntry(int32_t worker_id) {
ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id);
while (rows.empty() == false) {
for (const auto &tupled_row : rows) {
std::vector<uint8_t> col_blob = std::get<0>(tupled_row);
mindrecord::json col_jsn = std::get<1>(tupled_row);
std::string attr = col_jsn["attribute"];
if (attr == "n") {
std::shared_ptr<Node> node_ptr;
RETURN_IF_NOT_OK(
LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
n_deques_[worker_id].emplace_back(node_ptr);
} else if (attr == "e") {
std::shared_ptr<Edge> edge_ptr;
RETURN_IF_NOT_OK(
LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
e_deques_[worker_id].emplace_back(edge_ptr);
} else {
MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node.";
}
}
rows = shard_reader_->GetNextById(row_id_++, worker_id);
}
return Status::OK();
}
void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map,
DefaultFeatureMap *default_feature_map) {
for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
for (auto &m : n_feature_maps_[wkr_id]) {
for (auto &n : m.second) (*n_feature_map)[m.first].insert(n);
}
for (auto &m : e_feature_maps_[wkr_id]) {
for (auto &n : m.second) (*e_feature_map)[m.first].insert(n);
}
for (auto &m : default_feature_maps_[wkr_id]) {
(*default_feature_map)[m.first] = m.second;
}
}
n_feature_maps_.clear();
e_feature_maps_.clear();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,126 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_GNN_GRAPH_LOADER_H_
#define DATASET_ENGINE_GNN_GRAPH_LOADER_H_
#include <deque>
#include <memory>
#include <queue>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "dataset/core/data_type.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/gnn/feature.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/engine/gnn/node.h"
#include "dataset/engine/gnn/edge.h"
#include "dataset/util/status.h"
#include "mindrecord/include/shard_reader.h"
namespace mindspore {
namespace dataset {
namespace gnn {
using mindrecord::ShardReader;
using NodeIdMap = std::unordered_map<NodeIdType, std::shared_ptr<Node>>;
using EdgeIdMap = std::unordered_map<EdgeIdType, std::shared_ptr<Edge>>;
using NodeTypeMap = std::unordered_map<NodeType, std::vector<NodeIdType>>;
using EdgeTypeMap = std::unordered_map<EdgeType, std::vector<EdgeIdType>>;
using NodeFeatureMap = std::unordered_map<NodeType, std::unordered_set<FeatureType>>;
using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureType>>;
using DefaultFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
// this class interfaces with the underlying storage format (mindrecord)
// it returns raw nodes and edges via GetNodesAndEdges
// it is then the responsibility of graph to construct itself based on the nodes and edges
// if needed, this class could become a base where each derived class handles a specific storage format
class GraphLoader {
public:
explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4);
// Init mindrecord and load everything into memory multi-threaded
// @return Status - the status code
Status InitAndLoad();
// this function will query mindrecord and construct all nodes and edges
// nodes and edges are added to map without any connection. That's because there nodes and edges are read in
// random order. src_node and dst_node in Edge are node_id only with -1 as type.
// features attached to each node and edge are expected to be filled correctly
Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *,
DefaultFeatureMap *);
private:
//
// worker thread that reads mindrecord file
// @param int32_t worker_id - id of each worker
// @return Status - the status code
Status WorkerEntry(int32_t worker_id);
// Load a node based on 1 row of mindrecord, returns a shared_ptr<Node>
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::shared_ptr<Node> *node - return value
// @param NodeFeatureMap *feature_map -
// @param DefaultFeatureMap *default_feature -
// @return Status - the status code
Status LoadNode(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Node> *node,
NodeFeatureMap *feature_map, DefaultFeatureMap *default_feature);
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::shared_ptr<Edge> *edge - return value, the edge ptr, edge is not yet connected
// @param FeatureMap *feature_map
// @param DefaultFeatureMap *default_feature -
// @return Status - the status code
Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
EdgeFeatureMap *feature_map, DefaultFeatureMap *default_feature);
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
// @return Status - the status code
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
std::vector<int32_t> *ind);
// @param std::string &key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
// @return Status - the status code
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
std::shared_ptr<Tensor> *tensor);
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultFeatureMap *);
const int32_t num_workers_;
std::atomic_int row_id_;
std::string mr_path_;
std::unique_ptr<ShardReader> shard_reader_;
std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
std::vector<NodeFeatureMap> n_feature_maps_;
std::vector<EdgeFeatureMap> e_feature_maps_;
std::vector<DefaultFeatureMap> default_feature_maps_;
const std::vector<std::string> keys_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_GNN_GRAPH_LOADER_H_

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/gnn/local_edge.h"
#include <string>
namespace mindspore {
namespace dataset {
namespace gnn {
LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
: Edge(id, type, src_node, dst_node) {}
Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = features_.find(feature_type);
if (itr != features_.end()) {
*out_feature = itr->second;
return Status::OK();
} else {
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
Status LocalEdge::UpdateFeature(const std::shared_ptr<Feature> &feature) {
auto itr = features_.find(feature->type());
if (itr != features_.end()) {
RETURN_STATUS_UNEXPECTED("Feature already exists");
} else {
features_[feature->type()] = feature;
return Status::OK();
}
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_GNN_LOCAL_EDGE_H_
#define DATASET_ENGINE_GNN_LOCAL_EDGE_H_
#include <memory>
#include <unordered_map>
#include <utility>
#include "dataset/util/status.h"
#include "dataset/engine/gnn/edge.h"
#include "dataset/engine/gnn/feature.h"
#include "dataset/engine/gnn/node.h"
namespace mindspore {
namespace dataset {
namespace gnn {
class LocalEdge : public Edge {
public:
// Constructor
// @param EdgeIdType id - edge id
// @param EdgeType type - edge type
// @param std::shared_ptr<Node> src_node - source node
// @param std::shared_ptr<Node> dst_node - destination node
LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node);
~LocalEdge() = default;
// Get the feature of a edge
// @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
// Update feature of edge
// @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
private:
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_GNN_LOCAL_EDGE_H_

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save