Gnn data processing supports distributed scenarios

pull/4498/head
heleiwang 5 years ago
parent 1ca715c7e7
commit 8ee4d8e92d

@ -15,7 +15,14 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
SET(MS_BUILD_GRPC 0)
if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES)
SET(MS_BUILD_GRPC 1)
endif()
if (ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
SET(MS_BUILD_GRPC 1)
endif()
if ("${MS_BUILD_GRPC}")
# build dependencies of gRPC
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake)

@ -83,6 +83,7 @@ endif()
if (ENABLE_TDTQUE)
add_dependencies(engine-tdt core)
endif ()
################### Create _c_dataengine Library ######################
set(submodules
$<TARGET_OBJECTS:core>
@ -182,3 +183,7 @@ else()
set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON)
endif ()
endif()
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
endif()

@ -18,83 +18,103 @@
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/gnn/graph.h"
#include "minddata/dataset/engine/gnn/graph_data_client.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/graph_data_server.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
Graph, 0, ([](const py::module *m) {
(void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
.def(py::init([](std::string dataset_file, int32_t num_workers) {
std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
THROW_IF_ERROR(g_out->Init());
return g_out;
(void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient")
.def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &working_mode,
const std::string &hostname, int32_t port) {
std::shared_ptr<gnn::GraphData> out;
if (working_mode == "local") {
out = std::make_shared<gnn::GraphDataImpl>(dataset_file, num_workers);
} else if (working_mode == "client") {
out = std::make_shared<gnn::GraphDataClient>(dataset_file, hostname, port);
}
THROW_IF_ERROR(out->Init());
return out;
}))
.def("get_all_nodes",
[](gnn::Graph &g, gnn::NodeType node_type) {
[](gnn::GraphData &g, gnn::NodeType node_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
return out;
})
.def("get_all_edges",
[](gnn::Graph &g, gnn::EdgeType edge_type) {
[](gnn::GraphData &g, gnn::EdgeType edge_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
return out;
})
.def("get_nodes_from_edges",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> edge_list) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
return out;
})
.def("get_all_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
return out;
})
.def("get_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
std::vector<gnn::NodeType> neighbor_types) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
return out;
})
.def("get_neg_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
gnn::NodeType neg_neighbor_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
return out;
})
.def("get_node_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
[](gnn::GraphData &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out;
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
return out.getRow();
})
.def("get_edge_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
[](gnn::GraphData &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out;
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
return out.getRow();
})
.def("graph_info",
[](gnn::Graph &g) {
[](gnn::GraphData &g) {
py::dict out;
THROW_IF_ERROR(g.GraphInfo(&out));
return out;
})
.def("random_walk",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
return out;
});
})
.def("stop", [](gnn::GraphData &g) { THROW_IF_ERROR(g.Stop()); });
(void)py::class_<gnn::GraphDataServer, std::shared_ptr<gnn::GraphDataServer>>(*m, "GraphDataServer")
.def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
int32_t client_num, bool auto_shutdown) {
std::shared_ptr<gnn::GraphDataServer> out;
out =
std::make_shared<gnn::GraphDataServer>(dataset_file, num_workers, hostname, port, client_num, auto_shutdown);
THROW_IF_ERROR(out->Init());
return out;
}))
.def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); })
.def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); });
}));
} // namespace dataset

@ -1,9 +1,29 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-gnn OBJECT
graph.cc
set(DATASET_ENGINE_GNN_SRC_FILES
graph_data_impl.cc
graph_data_client.cc
graph_data_server.cc
graph_loader.cc
graph_feature_parser.cc
local_node.cc
local_edge.cc
feature.cc
)
)
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES})
else()
set(DATASET_ENGINE_GNN_SRC_FILES
${DATASET_ENGINE_GNN_SRC_FILES}
tensor_proto.cc
grpc_async_server.cc
graph_data_service_impl.cc
graph_shared_memory.cc)
ms_protobuf_generate(TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS "gnn_tensor.proto")
ms_grpc_generate(GNN_PROTO_SRCS GNN_PROTO_HDRS "gnn_graph_data.proto")
add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES} ${TENSOR_PROTO_SRCS} ${GNN_PROTO_SRCS})
add_dependencies(engine-gnn mindspore::protobuf)
endif()

@ -19,7 +19,8 @@ namespace mindspore {
namespace dataset {
namespace gnn {
Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value) : type_name_(type_name), value_(value) {}
Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory)
: type_name_(type_name), value_(value), is_shared_memory_(is_shared_memory) {}
} // namespace gnn
} // namespace dataset

@ -31,7 +31,7 @@ class Feature {
// Constructor
// @param FeatureType type_name - feature type
// @param std::shared_ptr<Tensor> value - feature value
Feature(FeatureType type_name, std::shared_ptr<Tensor> value);
Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory = false);
~Feature() = default;
@ -45,6 +45,7 @@ class Feature {
private:
FeatureType type_name_;
std::shared_ptr<Tensor> value_;
bool is_shared_memory_;
};
} // namespace gnn
} // namespace dataset

@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.dataset;
import "gnn_tensor.proto";
message GnnClientRegisterRequestPb {
int32 pid = 1;
}
message GnnFeatureInfoPb {
int32 type = 1;
TensorPb feature = 2;
}
message GnnClientRegisterResponsePb {
string error_msg = 1;
string data_schema = 2;
int64 shared_memory_key = 3;
int64 shared_memory_size = 4;
repeated GnnFeatureInfoPb default_node_feature = 5;
repeated GnnFeatureInfoPb default_edge_feature = 6;
}
message GnnClientUnRegisterRequestPb {
int32 pid = 1;
}
message GnnClientUnRegisterResponsePb {
string error_msg = 1;
}
enum GnnOpName {
GET_ALL_NODES = 0;
GET_ALL_EDGES = 1;
GET_NODES_FROM_EDGES = 2;
GET_ALL_NEIGHBORS = 3;
GET_SAMPLED_NEIGHBORS = 4;
GET_NEG_SAMPLED_NEIGHBORS = 5;
RANDOM_WALK = 6;
GET_NODE_FEATURE = 7;
GET_EDGE_FEATURE = 8;
}
message GnnRandomWalkPb {
float p = 1;
float q = 2;
int32 default_id = 3;
}
message GnnGraphDataRequestPb {
GnnOpName op_name = 1;
repeated int32 id = 2; // node id or edge id
repeated int32 type = 3; //node type or edge type or neighbor type or feature type
repeated int32 number = 4; // samples number
TensorPb id_tensor = 5; // input ids ,node id or edge id
GnnRandomWalkPb random_walk = 6;
}
message GnnGraphDataResponsePb {
string error_msg = 1;
repeated TensorPb result_data = 2;
}
message GnnMetaInfoRequestPb {
}
message GnnNodeEdgeInfoPb {
int32 type = 1;
int32 num = 2;
}
message GnnMetaInfoResponsePb {
string error_msg = 1;
repeated GnnNodeEdgeInfoPb node_info = 2;
repeated GnnNodeEdgeInfoPb edge_info = 3;
repeated int32 node_feature_type = 4;
repeated int32 edge_feature_type = 5;
}
service GnnGraphData {
rpc ClientRegister(GnnClientRegisterRequestPb) returns (GnnClientRegisterResponsePb);
rpc ClientUnRegister(GnnClientUnRegisterRequestPb) returns (GnnClientUnRegisterResponsePb);
rpc GetGraphData(GnnGraphDataRequestPb) returns (GnnGraphDataResponsePb);
rpc GetMetaInfo(GnnMetaInfoRequestPb) returns (GnnMetaInfoResponsePb);
}

@ -0,0 +1,42 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.dataset;
enum DataTypePb {
DE_PB_UNKNOWN = 0;
DE_PB_BOOL = 1;
DE_PB_INT8 = 2;
DE_PB_UINT8 = 3;
DE_PB_INT16 = 4;
DE_PB_UINT16 = 5;
DE_PB_INT32 = 6;
DE_PB_UINT32 = 7;
DE_PB_INT64 = 8;
DE_PB_UINT64 = 9;
DE_PB_FLOAT16 = 10;
DE_PB_FLOAT32 = 11;
DE_PB_FLOAT64 = 12;
DE_PB_STRING = 13;
}
message TensorPb {
repeated int64 dims = 1; // tensor shape info
DataTypePb tensor_type = 2; // tensor content data type
bytes data = 3; // tensor data
}

@ -0,0 +1,134 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace gnn {
struct MetaInfo {
std::vector<NodeType> node_type;
std::vector<EdgeType> edge_type;
std::map<NodeType, NodeIdType> node_num;
std::map<EdgeType, EdgeIdType> edge_num;
std::vector<FeatureType> node_feature_type;
std::vector<FeatureType> edge_feature_type;
};
class GraphData {
public:
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status - The error code return
virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) = 0;
// Get sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) = 0;
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) = 0;
// Get the feature of a edge
// @param std::shared_ptr<Tensor> edges - List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) = 0;
// Return meta information to python layer
virtual Status GraphInfo(py::dict *out) = 0;
virtual Status Init() = 0;
virtual Status Stop() = 0;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_

@ -0,0 +1,185 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
#include <algorithm>
#include <memory>
#include <string>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <utility>
#if !defined(_WIN32) && !defined(_WIN64)
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"
#endif
#include "minddata/dataset/engine/gnn/graph_data.h"
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_column.h"
namespace mindspore {
namespace dataset {
namespace gnn {
class GraphDataClient : public GraphData {
public:
// Constructor
// @param std::string dataset_file -
// @param int32_t num_workers - number of parallel threads
GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port);
~GraphDataClient();
Status Init() override;
Status Stop() override;
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status - The error code return
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) override;
// Get sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) override;
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
// Get the feature of a edge
// @param std::shared_ptr<Tensor> edges - List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
// Return meta information to python layer
Status GraphInfo(py::dict *out) override;
private:
#if !defined(_WIN32) && !defined(_WIN64)
Status ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type,
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
Status ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type,
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response);
Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
std::shared_ptr<Tensor> *out);
Status RegisterToServer();
Status UnRegisterToServer();
Status InitFeatureParser();
Status CheckPid() {
CHECK_FAIL_RETURN_UNEXPECTED(pid_ == getpid(),
"Multi-process mode is not supported, please change to use multi-thread");
return Status::OK();
}
#endif
std::string dataset_file_;
std::string host_;
int32_t port_;
int32_t pid_;
mindrecord::json data_schema_;
#if !defined(_WIN32) && !defined(_WIN64)
std::unique_ptr<GnnGraphData::Stub> stub_;
key_t shared_memory_key_;
int64_t shared_memory_size_;
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_node_feature_map_;
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_edge_feature_map_;
#endif
bool registered_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
#include <algorithm>
#include <memory>
@ -25,13 +25,11 @@
#include <vector>
#include <utility>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/gnn/graph_loader.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/graph_data.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/mindrecord/include/common/shard_utils.h"
namespace mindspore {
namespace dataset {
@ -41,41 +39,32 @@ const float kGnnEpsilon = 0.0001;
const uint32_t kMaxNumWalks = 80;
using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
struct MetaInfo {
std::vector<NodeType> node_type;
std::vector<EdgeType> edge_type;
std::map<NodeType, NodeIdType> node_num;
std::map<EdgeType, EdgeIdType> edge_num;
std::vector<FeatureType> node_feature_type;
std::vector<FeatureType> edge_feature_type;
};
class Graph {
class GraphDataImpl : public GraphData {
public:
// Constructor
// @param std::string dataset_file -
// @param int32_t num_workers - number of parallel threads
Graph(std::string dataset_file, int32_t num_workers);
GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode = false);
~Graph() = default;
~GraphDataImpl();
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out);
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out);
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out);
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
@ -85,7 +74,7 @@ class Graph {
// is not enough, fill in tensor as -1.
// @return Status - The error code return
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out);
std::shared_ptr<Tensor> *out) override;
// Get sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
@ -94,7 +83,7 @@ class Graph {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
@ -103,7 +92,7 @@ class Graph {
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
@ -115,7 +104,7 @@ class Graph {
// @return Status - The error code return
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out);
std::shared_ptr<Tensor> *out) override;
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
@ -124,16 +113,22 @@ class Graph {
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out);
TensorRow *out) override;
Status GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type,
std::shared_ptr<Tensor> *out);
// Get the feature of a edge
// @param std::shared_ptr<Tensor> edget - List of edges
// @param std::shared_ptr<Tensor> edges - List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edget, const std::vector<FeatureType> &feature_types,
TensorRow *out);
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
Status GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type,
std::shared_ptr<Tensor> *out);
// Get meta information of graph
// @param MetaInfo *meta_info - Returned meta information
@ -142,15 +137,34 @@ class Graph {
#ifdef ENABLE_PYTHON
// Return meta information to python layer
Status GraphInfo(py::dict *out);
Status GraphInfo(py::dict *out) override;
#endif
Status Init();
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultNodeFeatures() {
return &default_node_feature_map_;
}
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultEdgeFeatures() {
return &default_edge_feature_map_;
}
Status Init() override;
Status Stop() override { return Status::OK(); }
std::string GetDataSchema() { return data_schema_.dump(); }
#if !defined(_WIN32) && !defined(_WIN64)
key_t GetSharedMemoryKey() { return graph_shared_memory_->memory_key(); }
int64_t GetSharedMemorySize() { return graph_shared_memory_->memory_size(); }
#endif
private:
friend class GraphLoader;
class RandomWalkBase {
public:
explicit RandomWalkBase(Graph *graph);
explicit RandomWalkBase(GraphDataImpl *graph);
Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1,
@ -176,7 +190,7 @@ class Graph {
template <typename T>
std::vector<float> Normalize(const std::vector<T> &non_normalized_probability);
Graph *graph_;
GraphDataImpl *graph_;
std::vector<NodeIdType> node_list_;
std::vector<NodeType> meta_path_;
float step_home_param_; // Return hyper parameter. Default is 1.0
@ -248,7 +262,11 @@ class Graph {
int32_t num_workers_; // The number of worker threads
std::mt19937 rnd_;
RandomWalkBase random_walk_;
mindrecord::json data_schema_;
bool server_mode_;
#if !defined(_WIN32) && !defined(_WIN64)
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
#endif
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
@ -264,4 +282,4 @@ class Graph {
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_

@ -0,0 +1,133 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_data_server.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <utility>
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
namespace gnn {
GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname,
int32_t port, int32_t client_num, bool auto_shutdown)
: dataset_file_(dataset_file),
num_workers_(num_workers),
client_num_(client_num),
max_connected_client_num_(0),
auto_shutdown_(auto_shutdown),
state_(kGdsUninit) {
tg_ = std::make_unique<TaskGroup>();
graph_data_impl_ = std::make_unique<GraphDataImpl>(dataset_file, num_workers, true);
#if !defined(_WIN32) && !defined(_WIN64)
service_impl_ = std::make_unique<GraphDataServiceImpl>(this, graph_data_impl_.get());
async_server_ = std::make_unique<GraphDataGrpcServer>(hostname, port, service_impl_.get());
#endif
}
Status GraphDataServer::Init() {
#if defined(_WIN32) || defined(_WIN64)
RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS");
#else
set_state(kGdsInitializing);
RETURN_IF_NOT_OK(async_server_->Run());
// RETURN_IF_NOT_OK(InitGraphDataImpl());
RETURN_IF_NOT_OK(tg_->CreateAsyncTask("init graph data impl", std::bind(&GraphDataServer::InitGraphDataImpl, this)));
for (int32_t i = 0; i < num_workers_; ++i) {
RETURN_IF_NOT_OK(
tg_->CreateAsyncTask("start async rpc service", std::bind(&GraphDataServer::StartAsyncRpcService, this)));
}
if (auto_shutdown_) {
RETURN_IF_NOT_OK(
tg_->CreateAsyncTask("judge auto shutdown server", std::bind(&GraphDataServer::JudgeAutoShutdownServer, this)));
}
return Status::OK();
#endif
}
Status GraphDataServer::InitGraphDataImpl() {
TaskManager::FindMe()->Post();
Status s = graph_data_impl_->Init();
if (s.IsOk()) {
set_state(kGdsRunning);
} else {
(void)Stop();
}
return s;
}
#if !defined(_WIN32) && !defined(_WIN64)
Status GraphDataServer::StartAsyncRpcService() {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(async_server_->HandleRequest());
return Status::OK();
}
#endif
Status GraphDataServer::JudgeAutoShutdownServer() {
TaskManager::FindMe()->Post();
while (true) {
if (auto_shutdown_ && (max_connected_client_num_ >= client_num_) && (client_pid_.size() == 0)) {
MS_LOG(INFO) << "All clients have been unregister, automatically exit the server.";
RETURN_IF_NOT_OK(Stop());
break;
}
if (state_ == kGdsStopped) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
return Status::OK();
}
Status GraphDataServer::Stop() {
#if !defined(_WIN32) && !defined(_WIN64)
async_server_->Stop();
#endif
set_state(kGdsStopped);
graph_data_impl_.reset();
return Status::OK();
}
Status GraphDataServer::ClientRegister(int32_t pid) {
std::unique_lock<std::mutex> lck(mutex_);
MS_LOG(INFO) << "client register pid:" << std::to_string(pid);
client_pid_.emplace(pid);
if (client_pid_.size() > max_connected_client_num_) {
max_connected_client_num_ = client_pid_.size();
}
return Status::OK();
}
Status GraphDataServer::ClientUnRegister(int32_t pid) {
std::unique_lock<std::mutex> lck(mutex_);
auto itr = client_pid_.find(pid);
if (itr != client_pid_.end()) {
client_pid_.erase(itr);
MS_LOG(INFO) << "client unregister pid:" << std::to_string(pid);
}
return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,196 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
#include <memory>
#include <mutex>
#include <string>
#include <unordered_set>
#if !defined(_WIN32) && !defined(_WIN64)
#include "grpcpp/grpcpp.h"
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
#endif
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
namespace gnn {
class GraphDataImpl;
class GraphDataServer {
public:
enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped };
GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
int32_t client_num, bool auto_shutdown);
~GraphDataServer() = default;
Status Init();
Status Stop();
Status ClientRegister(int32_t pid);
Status ClientUnRegister(int32_t pid);
enum ServerState state() { return state_; }
bool IsStoped() {
if (state_ == kGdsStopped) {
return true;
} else {
return false;
}
}
private:
void set_state(enum ServerState state) { state_ = state; }
Status InitGraphDataImpl();
#if !defined(_WIN32) && !defined(_WIN64)
Status StartAsyncRpcService();
#endif
Status JudgeAutoShutdownServer();
std::string dataset_file_;
int32_t num_workers_; // The number of worker threads
int32_t client_num_;
int32_t max_connected_client_num_;
bool auto_shutdown_;
enum ServerState state_;
std::unique_ptr<TaskGroup> tg_; // Class for worker management
std::unique_ptr<GraphDataImpl> graph_data_impl_;
std::unordered_set<int32_t> client_pid_;
std::mutex mutex_;
#if !defined(_WIN32) && !defined(_WIN64)
std::unique_ptr<GraphDataServiceImpl> service_impl_;
std::unique_ptr<GrpcAsyncServer> async_server_;
#endif
};
#if !defined(_WIN32) && !defined(_WIN64)
class UntypedCall {
public:
virtual ~UntypedCall() {}
virtual Status operator()() = 0;
};
template <class ServiceImpl, class AsyncService, class RequestMessage, class ResponseMessage>
class CallData : public UntypedCall {
public:
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
using EnqueueFunction = void (AsyncService::*)(grpc::ServerContext *, RequestMessage *,
grpc::ServerAsyncResponseWriter<ResponseMessage> *,
grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
using HandleRequestFunction = grpc::Status (ServiceImpl::*)(grpc::ServerContext *, const RequestMessage *,
ResponseMessage *);
CallData(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function)
: status_(STATE::CREATE),
service_impl_(service_impl),
async_service_(async_service),
cq_(cq),
enqueue_function_(enqueue_function),
handle_request_function_(handle_request_function),
responder_(&ctx_) {}
~CallData() = default;
static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) {
auto call = new CallData<ServiceImpl, AsyncService, RequestMessage, ResponseMessage>(
service_impl, async_service, cq, enqueue_function, handle_request_function);
RETURN_IF_NOT_OK((*call)());
return Status::OK();
}
Status operator()() {
if (status_ == STATE::CREATE) {
status_ = STATE::PROCESS;
(async_service_->*enqueue_function_)(&ctx_, &request_, &responder_, cq_, cq_, this);
} else if (status_ == STATE::PROCESS) {
EnqueueRequest(service_impl_, async_service_, cq_, enqueue_function_, handle_request_function_);
status_ = STATE::FINISH;
// new CallData(service_, cq_, this->s_type_);
grpc::Status s = (service_impl_->*handle_request_function_)(&ctx_, &request_, &response_);
responder_.Finish(response_, s, this);
} else {
GPR_ASSERT(status_ == STATE::FINISH);
delete this;
}
return Status::OK();
}
private:
STATE status_;
ServiceImpl *service_impl_;
AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
EnqueueFunction enqueue_function_;
HandleRequestFunction handle_request_function_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
RequestMessage request_;
ResponseMessage response_;
};
#define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg) \
do { \
Status s = \
CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, request_msg, response_msg>::EnqueueRequest( \
service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method, \
&gnn::GraphDataServiceImpl::method); \
RETURN_IF_NOT_OK(s); \
} while (0)
class GraphDataGrpcServer : public GrpcAsyncServer {
public:
GraphDataGrpcServer(const std::string &host, int32_t port, GraphDataServiceImpl *service_impl)
: GrpcAsyncServer(host, port), service_impl_(service_impl) {}
Status RegisterService(grpc::ServerBuilder *builder) {
builder->RegisterService(&svc_);
return Status::OK();
}
Status EnqueueRequest() {
ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientRegister, GnnClientRegisterRequestPb,
GnnClientRegisterResponsePb);
ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientUnRegister, GnnClientUnRegisterRequestPb,
GnnClientUnRegisterResponsePb);
ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetGraphData, GnnGraphDataRequestPb, GnnGraphDataResponsePb);
ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb);
return Status::OK();
}
Status ProcessRequest(void *tag) {
auto rq = static_cast<UntypedCall *>(tag);
RETURN_IF_NOT_OK((*rq)());
return Status::OK();
}
private:
GraphDataServiceImpl *service_impl_;
GnnGraphData::AsyncService svc_;
};
#endif
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
#include <memory>
#include <string>
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"
namespace mindspore {
namespace dataset {
namespace gnn {
class GraphDataServer;
// class GraphDataServiceImpl : public GnnGraphData::Service {
class GraphDataServiceImpl {
public:
GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl);
~GraphDataServiceImpl() = default;
grpc::Status ClientRegister(grpc::ServerContext *context, const GnnClientRegisterRequestPb *request,
GnnClientRegisterResponsePb *response);
grpc::Status ClientUnRegister(grpc::ServerContext *context, const GnnClientUnRegisterRequestPb *request,
GnnClientUnRegisterResponsePb *response);
grpc::Status GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request,
GnnGraphDataResponsePb *response);
grpc::Status GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request,
GnnMetaInfoResponsePb *response);
Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
private:
Status FillDefaultFeature(GnnClientRegisterResponsePb *response);
GraphDataServer *server_;
GraphDataImpl *graph_data_impl_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_

@ -0,0 +1,106 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#include <memory>
#include <utility>
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
namespace mindspore {
namespace dataset {
namespace gnn {
using mindrecord::MSRStatus;
GraphFeatureParser::GraphFeatureParser(const ShardColumn &shard_column) {
shard_column_ = std::make_unique<ShardColumn>(shard_column);
}
Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
std::shared_ptr<Tensor> *tensor) {
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])),
data, tensor));
return Status::OK();
}
#if !defined(_WIN32) && !defined(_WIN64)
Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
GraphSharedMemory *shared_memory,
std::shared_ptr<Tensor> *out_tensor) {
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor));
auto fea_itr = tensor->begin<int64_t>();
int64_t offset = 0;
RETURN_IF_NOT_OK(shared_memory->InsertData(data, n_bytes, &offset));
*fea_itr = offset;
++fea_itr;
*fea_itr = n_bytes;
*out_tensor = std::move(tensor);
return Status::OK();
}
#endif
Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
std::vector<int32_t> *indices) {
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
for (int i = 0; i < n_bytes; i += col_type_size) {
int32_t feature_ind = -1;
if (col_type == mindrecord::ColumnInt32) {
feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
} else if (col_type == mindrecord::ColumnInt64) {
feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
} else {
RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
}
if (feature_ind >= 0) indices->push_back(feature_ind);
}
return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_column.h"
namespace mindspore {
namespace dataset {
namespace gnn {
using mindrecord::ShardColumn;
class GraphFeatureParser {
public:
explicit GraphFeatureParser(const ShardColumn &shard_column);
~GraphFeatureParser() = default;
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
// @return Status - the status code
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, std::vector<int32_t> *ind);
// @param std::string &key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
// @return Status - the status code
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, std::shared_ptr<Tensor> *tensor);
#if !defined(_WIN32) && !defined(_WIN64)
Status LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
GraphSharedMemory *shared_memory, std::shared_ptr<Tensor> *out_tensor);
#endif
private:
std::unique_ptr<ShardColumn> shard_column_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_

File diff suppressed because it is too large Load Diff

@ -26,10 +26,13 @@
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/graph.h"
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_reader.h"
namespace mindspore {
@ -46,13 +49,15 @@ using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureTy
using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
class GraphDataImpl;
// this class interfaces with the underlying storage format (mindrecord)
// it returns raw nodes and edges via GetNodesAndEdges
// it is then the responsibility of graph to construct itself based on the nodes and edges
// if needed, this class could become a base where each derived class handles a specific storage format
class GraphLoader {
public:
explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4);
GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false);
~GraphLoader() = default;
// Init mindrecord and load everything into memory multi-threaded
@ -63,8 +68,7 @@ class GraphLoader {
// nodes and edges are added to map without any connection. That's because there nodes and edges are read in
// random order. src_node and dst_node in Edge are node_id only with -1 as type.
// features attached to each node and edge are expected to be filled correctly
Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *,
DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *);
Status GetNodesAndEdges();
private:
//
@ -92,29 +96,15 @@ class GraphLoader {
Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature);
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
// @return Status - the status code
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
std::vector<int32_t> *ind);
// @param std::string &key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
// @return Status - the status code
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
std::shared_ptr<Tensor> *tensor);
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *);
void MergeFeatureMaps();
GraphDataImpl *graph_impl_;
std::string mr_path_;
const int32_t num_workers_;
std::atomic_int row_id_;
std::string mr_path_;
std::unique_ptr<ShardReader> shard_reader_;
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
std::vector<NodeFeatureMap> n_feature_maps_;

@ -0,0 +1,134 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#include <string>
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
namespace gnn {
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, key_t memory_key)
: memory_size_(memory_size),
memory_key_(memory_key),
memory_ptr_(nullptr),
memory_offset_(0),
is_new_create_(false) {
std::stringstream stream;
stream << std::hex << memory_key_;
memory_key_str_ = stream.str();
}
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, const std::string &mr_file)
: mr_file_(mr_file),
memory_size_(memory_size),
memory_key_(-1),
memory_ptr_(nullptr),
memory_offset_(0),
is_new_create_(false) {}
GraphSharedMemory::~GraphSharedMemory() {
if (is_new_create_) {
(void)DeleteSharedMemory();
}
}
Status GraphSharedMemory::CreateSharedMemory() {
if (memory_key_ == -1) {
// ftok to generate unique key
memory_key_ = ftok(mr_file_.data(), kGnnSharedMemoryId);
CHECK_FAIL_RETURN_UNEXPECTED(memory_key_ != -1, "Failed to get key of shared memory. file_name:" + mr_file_);
std::stringstream stream;
stream << std::hex << memory_key_;
memory_key_str_ = stream.str();
}
int shmflg = (0666 | IPC_CREAT | IPC_EXCL);
Status s = SharedMemoryImpl(shmflg);
if (s.IsOk()) {
is_new_create_ = true;
MS_LOG(INFO) << "Create shared memory success, key=0x" << memory_key_str_;
} else {
MS_LOG(WARNING) << "Shared memory with the same key may already exist, key=0x" << memory_key_str_;
shmflg = (0666 | IPC_CREAT);
s = SharedMemoryImpl(shmflg);
if (!s.IsOk()) {
RETURN_STATUS_UNEXPECTED("Create shared memory fao;ed, key=0x" + memory_key_str_);
}
}
return Status::OK();
}
Status GraphSharedMemory::GetSharedMemory() {
int shmflg = 0;
RETURN_IF_NOT_OK(SharedMemoryImpl(shmflg));
return Status::OK();
}
Status GraphSharedMemory::DeleteSharedMemory() {
int shmid = shmget(memory_key_, 0, 0);
CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
int result = shmctl(shmid, IPC_RMID, 0);
CHECK_FAIL_RETURN_UNEXPECTED(result != -1, "Failed to delete shared memory. key=0x" + memory_key_str_);
return Status::OK();
}
Status GraphSharedMemory::SharedMemoryImpl(const int &shmflg) {
// shmget returns an identifier in shmid
int shmid = shmget(memory_key_, memory_size_, shmflg);
CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
// shmat to attach to shared memory
auto data = shmat(shmid, reinterpret_cast<void *>(0), 0);
CHECK_FAIL_RETURN_UNEXPECTED(data != (char *)(-1), "Failed to address shared memory. key=0x" + memory_key_str_);
memory_ptr_ = reinterpret_cast<uint8_t *>(data);
return Status::OK();
}
Status GraphSharedMemory::InsertData(const uint8_t *data, int64_t len, int64_t *offset) {
CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
CHECK_FAIL_RETURN_UNEXPECTED(len > 0, "Input len is invalid.");
std::lock_guard<std::mutex> lck(mutex_);
CHECK_FAIL_RETURN_UNEXPECTED((memory_size_ - memory_offset_ >= len),
"Insufficient shared memory space to insert data.");
if (EOK != memcpy_s(memory_ptr_ + memory_offset_, memory_size_ - memory_offset_, data, len)) {
RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory.");
}
*offset = memory_offset_;
memory_offset_ += len;
return Status::OK();
}
Status GraphSharedMemory::GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len) {
CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
CHECK_FAIL_RETURN_UNEXPECTED(get_data_len > 0, "Input get_data_len is invalid.");
CHECK_FAIL_RETURN_UNEXPECTED(data_len >= get_data_len, "Insufficient target address space.");
CHECK_FAIL_RETURN_UNEXPECTED(memory_size_ >= get_data_len + offset,
"get_data_len is too large, beyond the space of shared memory.");
if (EOK != memcpy_s(data, data_len, memory_ptr_ + offset, get_data_len)) {
RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory.");
}
return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,72 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
#include <sys/ipc.h>
#include <sys/shm.h>
#include <mutex>
#include <string>
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace gnn {
const int kGnnSharedMemoryId = 65;
class GraphSharedMemory {
public:
explicit GraphSharedMemory(int64_t memory_size, key_t memory_key);
explicit GraphSharedMemory(int64_t memory_size, const std::string &mr_file);
~GraphSharedMemory();
// @param uint8_t** shared_memory - shared memory address
// @return Status - the status code
Status CreateSharedMemory();
// @param uint8_t** shared_memory - shared memory address
// @return Status - the status code
Status GetSharedMemory();
Status DeleteSharedMemory();
Status InsertData(const uint8_t *data, int64_t len, int64_t *offset);
Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len);
key_t memory_key() { return memory_key_; }
int64_t memory_size() { return memory_size_; }
private:
Status SharedMemoryImpl(const int &shmflg);
std::string mr_file_;
int64_t memory_size_;
key_t memory_key_;
std::string memory_key_str_;
uint8_t *memory_ptr_;
int64_t memory_offset_;
std::mutex mutex_;
bool is_new_create_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_

@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
#include <limits>
#include "minddata/dataset/util/task_manager.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
GrpcAsyncServer::GrpcAsyncServer(const std::string &host, int32_t port) : host_(host), port_(port) {}
GrpcAsyncServer::~GrpcAsyncServer() { Stop(); }
Status GrpcAsyncServer::Run() {
std::string server_address = host_ + ":" + std::to_string(port_);
grpc::ServerBuilder builder;
// Default message size for gRPC is 4MB. Increase it to 2g-1
builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);
int port_tcpip = 0;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip);
RETURN_IF_NOT_OK(RegisterService(&builder));
cq_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart();
if (server_) {
MS_LOG(INFO) << "Server listening on " << server_address;
} else {
std::string errMsg = "Fail to start server. ";
if (port_tcpip != port_) {
errMsg += "Unable to bind to address " + server_address + ".";
}
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status GrpcAsyncServer::HandleRequest() {
bool success;
void *tag;
// We loop through the grpc queue. Each connection if successful
// will come back with our own tag which is an instance of CallData
// and we simply call its functor. But first we need to create these instances
// and inject them into the grpc queue.
RETURN_IF_NOT_OK(EnqueueRequest());
while (cq_->Next(&tag, &success)) {
RETURN_IF_INTERRUPTED();
if (success) {
RETURN_IF_NOT_OK(ProcessRequest(tag));
} else {
MS_LOG(DEBUG) << "cq_->Next failed.";
}
}
return Status::OK();
}
void GrpcAsyncServer::Stop() {
if (server_) {
server_->Shutdown();
}
// Always shutdown the completion queue after the server.
if (cq_) {
cq_->Shutdown();
}
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "grpcpp/grpcpp.h"
#include "grpcpp/impl/codegen/async_unary_call.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief Async server base class
class GrpcAsyncServer {
public:
explicit GrpcAsyncServer(const std::string &host, int32_t port);
virtual ~GrpcAsyncServer();
/// \brief Brings up gRPC server
/// \return none
Status Run();
/// \brief Entry function to handle async server request
Status HandleRequest();
void Stop();
virtual Status RegisterService(grpc::ServerBuilder *builder) = 0;
virtual Status EnqueueRequest() = 0;
virtual Status ProcessRequest(void *tag) = 0;
protected:
int32_t port_;
std::string host_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save