parent
1ca715c7e7
commit
8ee4d8e92d
@ -1,9 +1,29 @@
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-gnn OBJECT
|
||||
graph.cc
|
||||
set(DATASET_ENGINE_GNN_SRC_FILES
|
||||
graph_data_impl.cc
|
||||
graph_data_client.cc
|
||||
graph_data_server.cc
|
||||
graph_loader.cc
|
||||
graph_feature_parser.cc
|
||||
local_node.cc
|
||||
local_edge.cc
|
||||
feature.cc
|
||||
)
|
||||
)
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES})
|
||||
else()
|
||||
set(DATASET_ENGINE_GNN_SRC_FILES
|
||||
${DATASET_ENGINE_GNN_SRC_FILES}
|
||||
tensor_proto.cc
|
||||
grpc_async_server.cc
|
||||
graph_data_service_impl.cc
|
||||
graph_shared_memory.cc)
|
||||
|
||||
ms_protobuf_generate(TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS "gnn_tensor.proto")
|
||||
ms_grpc_generate(GNN_PROTO_SRCS GNN_PROTO_HDRS "gnn_graph_data.proto")
|
||||
|
||||
add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES} ${TENSOR_PROTO_SRCS} ${GNN_PROTO_SRCS})
|
||||
add_dependencies(engine-gnn mindspore::protobuf)
|
||||
endif()
|
||||
|
@ -0,0 +1,103 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package mindspore.dataset;
|
||||
|
||||
import "gnn_tensor.proto";
|
||||
|
||||
message GnnClientRegisterRequestPb {
|
||||
int32 pid = 1;
|
||||
}
|
||||
|
||||
message GnnFeatureInfoPb {
|
||||
int32 type = 1;
|
||||
TensorPb feature = 2;
|
||||
}
|
||||
|
||||
message GnnClientRegisterResponsePb {
|
||||
string error_msg = 1;
|
||||
string data_schema = 2;
|
||||
int64 shared_memory_key = 3;
|
||||
int64 shared_memory_size = 4;
|
||||
repeated GnnFeatureInfoPb default_node_feature = 5;
|
||||
repeated GnnFeatureInfoPb default_edge_feature = 6;
|
||||
}
|
||||
|
||||
message GnnClientUnRegisterRequestPb {
|
||||
int32 pid = 1;
|
||||
}
|
||||
|
||||
message GnnClientUnRegisterResponsePb {
|
||||
string error_msg = 1;
|
||||
}
|
||||
|
||||
enum GnnOpName {
|
||||
GET_ALL_NODES = 0;
|
||||
GET_ALL_EDGES = 1;
|
||||
GET_NODES_FROM_EDGES = 2;
|
||||
GET_ALL_NEIGHBORS = 3;
|
||||
GET_SAMPLED_NEIGHBORS = 4;
|
||||
GET_NEG_SAMPLED_NEIGHBORS = 5;
|
||||
RANDOM_WALK = 6;
|
||||
GET_NODE_FEATURE = 7;
|
||||
GET_EDGE_FEATURE = 8;
|
||||
}
|
||||
|
||||
message GnnRandomWalkPb {
|
||||
float p = 1;
|
||||
float q = 2;
|
||||
int32 default_id = 3;
|
||||
}
|
||||
|
||||
message GnnGraphDataRequestPb {
|
||||
GnnOpName op_name = 1;
|
||||
repeated int32 id = 2; // node id or edge id
|
||||
repeated int32 type = 3; //node type or edge type or neighbor type or feature type
|
||||
repeated int32 number = 4; // samples number
|
||||
TensorPb id_tensor = 5; // input ids ,node id or edge id
|
||||
GnnRandomWalkPb random_walk = 6;
|
||||
}
|
||||
|
||||
message GnnGraphDataResponsePb {
|
||||
string error_msg = 1;
|
||||
repeated TensorPb result_data = 2;
|
||||
}
|
||||
|
||||
message GnnMetaInfoRequestPb {
|
||||
|
||||
}
|
||||
|
||||
message GnnNodeEdgeInfoPb {
|
||||
int32 type = 1;
|
||||
int32 num = 2;
|
||||
}
|
||||
|
||||
message GnnMetaInfoResponsePb {
|
||||
string error_msg = 1;
|
||||
repeated GnnNodeEdgeInfoPb node_info = 2;
|
||||
repeated GnnNodeEdgeInfoPb edge_info = 3;
|
||||
repeated int32 node_feature_type = 4;
|
||||
repeated int32 edge_feature_type = 5;
|
||||
}
|
||||
|
||||
service GnnGraphData {
|
||||
rpc ClientRegister(GnnClientRegisterRequestPb) returns (GnnClientRegisterResponsePb);
|
||||
rpc ClientUnRegister(GnnClientUnRegisterRequestPb) returns (GnnClientUnRegisterResponsePb);
|
||||
rpc GetGraphData(GnnGraphDataRequestPb) returns (GnnGraphDataResponsePb);
|
||||
rpc GetMetaInfo(GnnMetaInfoRequestPb) returns (GnnMetaInfoResponsePb);
|
||||
}
|
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package mindspore.dataset;
|
||||
|
||||
enum DataTypePb {
|
||||
DE_PB_UNKNOWN = 0;
|
||||
DE_PB_BOOL = 1;
|
||||
DE_PB_INT8 = 2;
|
||||
DE_PB_UINT8 = 3;
|
||||
DE_PB_INT16 = 4;
|
||||
DE_PB_UINT16 = 5;
|
||||
DE_PB_INT32 = 6;
|
||||
DE_PB_UINT32 = 7;
|
||||
DE_PB_INT64 = 8;
|
||||
DE_PB_UINT64 = 9;
|
||||
DE_PB_FLOAT16 = 10;
|
||||
DE_PB_FLOAT32 = 11;
|
||||
DE_PB_FLOAT64 = 12;
|
||||
DE_PB_STRING = 13;
|
||||
}
|
||||
|
||||
message TensorPb {
|
||||
repeated int64 dims = 1; // tensor shape info
|
||||
DataTypePb tensor_type = 2; // tensor content data type
|
||||
bytes data = 3; // tensor data
|
||||
}
|
@ -0,0 +1,134 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/core/tensor_row.h"
|
||||
#include "minddata/dataset/engine/gnn/feature.h"
|
||||
#include "minddata/dataset/engine/gnn/node.h"
|
||||
#include "minddata/dataset/engine/gnn/edge.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
struct MetaInfo {
|
||||
std::vector<NodeType> node_type;
|
||||
std::vector<EdgeType> edge_type;
|
||||
std::map<NodeType, NodeIdType> node_num;
|
||||
std::map<EdgeType, EdgeIdType> edge_num;
|
||||
std::vector<FeatureType> node_feature_type;
|
||||
std::vector<FeatureType> edge_feature_type;
|
||||
};
|
||||
|
||||
class GraphData {
|
||||
public:
|
||||
// Get all nodes from the graph.
|
||||
// @param NodeType node_type - type of node
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||
// @return Status - The error code return
|
||||
virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get all edges from the graph.
|
||||
// @param NodeType edge_type - type of edge
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The error code return
|
||||
virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get the node id from the edge.
|
||||
// @param std::vector<EdgeIdType> edge_list - List of edges
|
||||
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
||||
// @return Status - The error code return
|
||||
virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// All neighbors of the acquisition node.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||
// is not enough, fill in tensor as -1.
|
||||
// @return Status - The error code return
|
||||
virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||
// @return Status - The error code return
|
||||
virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||
const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get negative sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeIdType samples_num - Number of neighbors sampled
|
||||
// @param NodeType neg_neighbor_type - The type of negative neighbor.
|
||||
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
||||
// @return Status - The error code return
|
||||
virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Node2vec random walk.
|
||||
// @param std::vector<NodeIdType> node_list - List of nodes
|
||||
// @param std::vector<NodeType> meta_path - node type of each step
|
||||
// @param float step_home_param - return hyper parameter in node2vec algorithm
|
||||
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||
// @param NodeIdType default_node - default node id
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||
// @return Status - The error code return
|
||||
virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||
std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get the feature of a node
|
||||
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status - The error code return
|
||||
virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) = 0;
|
||||
|
||||
// Get the feature of a edge
|
||||
// @param std::shared_ptr<Tensor> edges - List of edges
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param Tensor *out - Returned features
|
||||
// @return Status - The error code return
|
||||
virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) = 0;
|
||||
|
||||
// Return meta information to python layer
|
||||
virtual Status GraphInfo(py::dict *out) = 0;
|
||||
|
||||
virtual Status Init() = 0;
|
||||
|
||||
virtual Status Stop() = 0;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,185 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#include "proto/gnn_graph_data.grpc.pb.h"
|
||||
#include "proto/gnn_graph_data.pb.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/gnn/graph_data.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
|
||||
#endif
|
||||
#include "minddata/mindrecord/include/common/shard_utils.h"
|
||||
#include "minddata/mindrecord/include/shard_column.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
class GraphDataClient : public GraphData {
|
||||
public:
|
||||
// Constructor
|
||||
// @param std::string dataset_file -
|
||||
// @param int32_t num_workers - number of parallel threads
|
||||
GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port);
|
||||
|
||||
~GraphDataClient();
|
||||
|
||||
Status Init() override;
|
||||
|
||||
Status Stop() override;
|
||||
|
||||
// Get all nodes from the graph.
|
||||
// @param NodeType node_type - type of node
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||
// @return Status - The error code return
|
||||
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get all edges from the graph.
|
||||
// @param NodeType edge_type - type of edge
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The error code return
|
||||
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get the node id from the edge.
|
||||
// @param std::vector<EdgeIdType> edge_list - List of edges
|
||||
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
||||
// @return Status - The error code return
|
||||
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// All neighbors of the acquisition node.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||
// is not enough, fill in tensor as -1.
|
||||
// @return Status - The error code return
|
||||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||
// @return Status - The error code return
|
||||
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get negative sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeIdType samples_num - Number of neighbors sampled
|
||||
// @param NodeType neg_neighbor_type - The type of negative neighbor.
|
||||
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
||||
// @return Status - The error code return
|
||||
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Node2vec random walk.
|
||||
// @param std::vector<NodeIdType> node_list - List of nodes
|
||||
// @param std::vector<NodeType> meta_path - node type of each step
|
||||
// @param float step_home_param - return hyper parameter in node2vec algorithm
|
||||
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||
// @param NodeIdType default_node - default node id
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||
// @return Status - The error code return
|
||||
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get the feature of a node
|
||||
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status - The error code return
|
||||
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) override;
|
||||
|
||||
// Get the feature of a edge
|
||||
// @param std::shared_ptr<Tensor> edges - List of edges
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param Tensor *out - Returned features
|
||||
// @return Status - The error code return
|
||||
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) override;
|
||||
|
||||
// Return meta information to python layer
|
||||
Status GraphInfo(py::dict *out) override;
|
||||
|
||||
private:
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
Status ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type,
|
||||
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
|
||||
|
||||
Status ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type,
|
||||
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
|
||||
|
||||
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
|
||||
|
||||
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
|
||||
|
||||
Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response);
|
||||
|
||||
Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
|
||||
std::shared_ptr<Tensor> *out);
|
||||
|
||||
Status RegisterToServer();
|
||||
|
||||
Status UnRegisterToServer();
|
||||
|
||||
Status InitFeatureParser();
|
||||
|
||||
Status CheckPid() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(pid_ == getpid(),
|
||||
"Multi-process mode is not supported, please change to use multi-thread");
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
std::string dataset_file_;
|
||||
std::string host_;
|
||||
int32_t port_;
|
||||
int32_t pid_;
|
||||
mindrecord::json data_schema_;
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
std::unique_ptr<GnnGraphData::Stub> stub_;
|
||||
key_t shared_memory_key_;
|
||||
int64_t shared_memory_size_;
|
||||
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
|
||||
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_node_feature_map_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_edge_feature_map_;
|
||||
#endif
|
||||
bool registered_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,133 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/gnn/graph_data_server.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname,
|
||||
int32_t port, int32_t client_num, bool auto_shutdown)
|
||||
: dataset_file_(dataset_file),
|
||||
num_workers_(num_workers),
|
||||
client_num_(client_num),
|
||||
max_connected_client_num_(0),
|
||||
auto_shutdown_(auto_shutdown),
|
||||
state_(kGdsUninit) {
|
||||
tg_ = std::make_unique<TaskGroup>();
|
||||
graph_data_impl_ = std::make_unique<GraphDataImpl>(dataset_file, num_workers, true);
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
service_impl_ = std::make_unique<GraphDataServiceImpl>(this, graph_data_impl_.get());
|
||||
async_server_ = std::make_unique<GraphDataGrpcServer>(hostname, port, service_impl_.get());
|
||||
#endif
|
||||
}
|
||||
|
||||
Status GraphDataServer::Init() {
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS");
|
||||
#else
|
||||
set_state(kGdsInitializing);
|
||||
RETURN_IF_NOT_OK(async_server_->Run());
|
||||
// RETURN_IF_NOT_OK(InitGraphDataImpl());
|
||||
RETURN_IF_NOT_OK(tg_->CreateAsyncTask("init graph data impl", std::bind(&GraphDataServer::InitGraphDataImpl, this)));
|
||||
for (int32_t i = 0; i < num_workers_; ++i) {
|
||||
RETURN_IF_NOT_OK(
|
||||
tg_->CreateAsyncTask("start async rpc service", std::bind(&GraphDataServer::StartAsyncRpcService, this)));
|
||||
}
|
||||
if (auto_shutdown_) {
|
||||
RETURN_IF_NOT_OK(
|
||||
tg_->CreateAsyncTask("judge auto shutdown server", std::bind(&GraphDataServer::JudgeAutoShutdownServer, this)));
|
||||
}
|
||||
return Status::OK();
|
||||
#endif
|
||||
}
|
||||
|
||||
Status GraphDataServer::InitGraphDataImpl() {
|
||||
TaskManager::FindMe()->Post();
|
||||
Status s = graph_data_impl_->Init();
|
||||
if (s.IsOk()) {
|
||||
set_state(kGdsRunning);
|
||||
} else {
|
||||
(void)Stop();
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
Status GraphDataServer::StartAsyncRpcService() {
|
||||
TaskManager::FindMe()->Post();
|
||||
RETURN_IF_NOT_OK(async_server_->HandleRequest());
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status GraphDataServer::JudgeAutoShutdownServer() {
|
||||
TaskManager::FindMe()->Post();
|
||||
while (true) {
|
||||
if (auto_shutdown_ && (max_connected_client_num_ >= client_num_) && (client_pid_.size() == 0)) {
|
||||
MS_LOG(INFO) << "All clients have been unregister, automatically exit the server.";
|
||||
RETURN_IF_NOT_OK(Stop());
|
||||
break;
|
||||
}
|
||||
if (state_ == kGdsStopped) {
|
||||
break;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataServer::Stop() {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
async_server_->Stop();
|
||||
#endif
|
||||
set_state(kGdsStopped);
|
||||
graph_data_impl_.reset();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataServer::ClientRegister(int32_t pid) {
|
||||
std::unique_lock<std::mutex> lck(mutex_);
|
||||
MS_LOG(INFO) << "client register pid:" << std::to_string(pid);
|
||||
client_pid_.emplace(pid);
|
||||
if (client_pid_.size() > max_connected_client_num_) {
|
||||
max_connected_client_num_ = client_pid_.size();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status GraphDataServer::ClientUnRegister(int32_t pid) {
|
||||
std::unique_lock<std::mutex> lck(mutex_);
|
||||
auto itr = client_pid_.find(pid);
|
||||
if (itr != client_pid_.end()) {
|
||||
client_pid_.erase(itr);
|
||||
MS_LOG(INFO) << "client unregister pid:" << std::to_string(pid);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,196 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#include "grpcpp/grpcpp.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
|
||||
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
|
||||
#endif
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
class GraphDataImpl;
|
||||
|
||||
class GraphDataServer {
|
||||
public:
|
||||
enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped };
|
||||
GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
|
||||
int32_t client_num, bool auto_shutdown);
|
||||
~GraphDataServer() = default;
|
||||
|
||||
Status Init();
|
||||
|
||||
Status Stop();
|
||||
|
||||
Status ClientRegister(int32_t pid);
|
||||
Status ClientUnRegister(int32_t pid);
|
||||
|
||||
enum ServerState state() { return state_; }
|
||||
|
||||
bool IsStoped() {
|
||||
if (state_ == kGdsStopped) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void set_state(enum ServerState state) { state_ = state; }
|
||||
|
||||
Status InitGraphDataImpl();
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
Status StartAsyncRpcService();
|
||||
#endif
|
||||
Status JudgeAutoShutdownServer();
|
||||
|
||||
std::string dataset_file_;
|
||||
int32_t num_workers_; // The number of worker threads
|
||||
int32_t client_num_;
|
||||
int32_t max_connected_client_num_;
|
||||
bool auto_shutdown_;
|
||||
enum ServerState state_;
|
||||
std::unique_ptr<TaskGroup> tg_; // Class for worker management
|
||||
std::unique_ptr<GraphDataImpl> graph_data_impl_;
|
||||
std::unordered_set<int32_t> client_pid_;
|
||||
std::mutex mutex_;
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
std::unique_ptr<GraphDataServiceImpl> service_impl_;
|
||||
std::unique_ptr<GrpcAsyncServer> async_server_;
|
||||
#endif
|
||||
};
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
class UntypedCall {
|
||||
public:
|
||||
virtual ~UntypedCall() {}
|
||||
|
||||
virtual Status operator()() = 0;
|
||||
};
|
||||
|
||||
// Manages the full lifecycle of one asynchronous unary RPC. Each instance is a
// small state machine: it registers itself on the completion queue (CREATE),
// handles the client request when its tag comes back (PROCESS), and deletes
// itself once the response has been written (FINISH). Template parameters bind
// the call to a concrete service implementation, its generated async service,
// and the request/response protobuf types.
template <class ServiceImpl, class AsyncService, class RequestMessage, class ResponseMessage>
class CallData : public UntypedCall {
 public:
  // Progress of a single RPC through its lifetime.
  enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
  // Pointer to the generated AsyncService::Request<Method> member function that
  // registers this call with the completion queue.
  using EnqueueFunction = void (AsyncService::*)(grpc::ServerContext *, RequestMessage *,
                                                 grpc::ServerAsyncResponseWriter<ResponseMessage> *,
                                                 grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
  // Pointer to the service-implementation member function that fills the response.
  using HandleRequestFunction = grpc::Status (ServiceImpl::*)(grpc::ServerContext *, const RequestMessage *,
                                                              ResponseMessage *);
  // All pointer arguments are borrowed and must outlive this call object.
  CallData(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
           EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function)
      : status_(STATE::CREATE),
        service_impl_(service_impl),
        async_service_(async_service),
        cq_(cq),
        enqueue_function_(enqueue_function),
        handle_request_function_(handle_request_function),
        responder_(&ctx_) {}

  ~CallData() = default;

  // Allocates a new CallData and steps it once so it registers itself on the
  // completion queue. The object is self-owned: it is reclaimed by
  // `delete this` when it reaches the FINISH state.
  static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
                               EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) {
    auto call = new CallData<ServiceImpl, AsyncService, RequestMessage, ResponseMessage>(
      service_impl, async_service, cq, enqueue_function, handle_request_function);
    RETURN_IF_NOT_OK((*call)());
    return Status::OK();
  }

  // Advances the state machine one step; called each time this object's tag
  // comes off the completion queue.
  Status operator()() {
    if (status_ == STATE::CREATE) {
      // Register this call so gRPC fills request_ when a client invokes the method.
      status_ = STATE::PROCESS;
      (async_service_->*enqueue_function_)(&ctx_, &request_, &responder_, cq_, cq_, this);
    } else if (status_ == STATE::PROCESS) {
      // Spawn a replacement call first so the server keeps accepting new
      // requests of this type while this one is being handled.
      EnqueueRequest(service_impl_, async_service_, cq_, enqueue_function_, handle_request_function_);
      status_ = STATE::FINISH;
      grpc::Status s = (service_impl_->*handle_request_function_)(&ctx_, &request_, &response_);
      responder_.Finish(response_, s, this);
    } else {
      GPR_ASSERT(status_ == STATE::FINISH);
      // Response has been sent; this object owns itself, so reclaim it now.
      delete this;
    }
    return Status::OK();
  }

 private:
  STATE status_;
  ServiceImpl *service_impl_;    // not owned
  AsyncService *async_service_;  // not owned
  grpc::ServerCompletionQueue *cq_;  // not owned
  EnqueueFunction enqueue_function_;
  HandleRequestFunction handle_request_function_;
  grpc::ServerContext ctx_;
  grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
  RequestMessage request_;
  ResponseMessage response_;
};
|
||||
|
||||
// Creates and enqueues one CallData for RPC `method` of GraphDataServiceImpl,
// pairing the generated AsyncService::Request##method registrar with the
// matching handler. Must expand inside a function returning Status; a failure
// propagates via RETURN_IF_NOT_OK.
#define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg)                          \
  do {                                                                                                               \
    Status s =                                                                                                       \
      CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, request_msg, response_msg>::EnqueueRequest(    \
        service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method,                               \
        &gnn::GraphDataServiceImpl::method);                                                                         \
    RETURN_IF_NOT_OK(s);                                                                                             \
  } while (0)
|
||||
|
||||
// Async gRPC server for the GNN graph-data service: registers the generated
// AsyncService with the builder, primes the completion queue with one pending
// call per RPC method, and dispatches completed tags back to their CallData
// state machines.
class GraphDataGrpcServer : public GrpcAsyncServer {
 public:
  // `service_impl` is borrowed and must outlive this server.
  GraphDataGrpcServer(const std::string &host, int32_t port, GraphDataServiceImpl *service_impl)
      : GrpcAsyncServer(host, port), service_impl_(service_impl) {}

  // Registers the generated async service before the server starts.
  Status RegisterService(grpc::ServerBuilder *builder) {
    builder->RegisterService(&svc_);
    return Status::OK();
  }

  // Seeds the completion queue with one pending CallData per RPC method so
  // every method can accept its first request.
  Status EnqueueRequest() {
    ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientRegister, GnnClientRegisterRequestPb,
                    GnnClientRegisterResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientUnRegister, GnnClientUnRegisterRequestPb,
                    GnnClientUnRegisterResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetGraphData, GnnGraphDataRequestPb, GnnGraphDataResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb);
    return Status::OK();
  }

  // Steps the state machine of the call whose tag came off the queue.
  Status ProcessRequest(void *tag) {
    auto rq = static_cast<UntypedCall *>(tag);
    RETURN_IF_NOT_OK((*rq)());
    return Status::OK();
  }

 private:
  GraphDataServiceImpl *service_impl_;  // not owned
  GnnGraphData::AsyncService svc_;
};
|
||||
#endif
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,70 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||
#include "proto/gnn_graph_data.grpc.pb.h"
|
||||
#include "proto/gnn_graph_data.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
class GraphDataServer;
|
||||
|
||||
// Implements the GNN graph-data RPC handlers. Deliberately NOT derived from
// the generated synchronous GnnGraphData::Service: the methods are invoked
// directly by the async CallData machinery instead.
class GraphDataServiceImpl {
 public:
  // Both pointers are borrowed and must outlive this object.
  GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl);
  ~GraphDataServiceImpl() = default;

  // RPC entry points, one per method of the GnnGraphData service.
  grpc::Status ClientRegister(grpc::ServerContext *context, const GnnClientRegisterRequestPb *request,
                              GnnClientRegisterResponsePb *response);

  grpc::Status ClientUnRegister(grpc::ServerContext *context, const GnnClientUnRegisterRequestPb *request,
                                GnnClientUnRegisterResponsePb *response);

  grpc::Status GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request,
                            GnnGraphDataResponsePb *response);

  grpc::Status GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request,
                           GnnMetaInfoResponsePb *response);

  // Per-operation helpers filling `response` from the underlying graph.
  // NOTE(review): presumably dispatched from GetGraphData by request op type;
  // confirm against the .cc implementation.
  Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);

 private:
  Status FillDefaultFeature(GnnClientRegisterResponsePb *response);

  GraphDataServer *server_;         // not owned
  GraphDataImpl *graph_data_impl_;  // not owned
};
|
||||
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
|
@ -0,0 +1,106 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
using mindrecord::MSRStatus;
|
||||
|
||||
// Keeps a private copy of the mindrecord column decoder used to parse
// feature blobs.
GraphFeatureParser::GraphFeatureParser(const ShardColumn &shard_column)
    : shard_column_(std::make_unique<ShardColumn>(shard_column)) {}
|
||||
|
||||
// Decodes the bytes of column `key` from a mindrecord blob into a flat 1-D
// feature tensor of n_bytes / col_type_size elements.
// @param key - column name
// @param col_blob - raw blob-field bytes for one row
// @param tensor - output feature tensor
// @return Status - error if the column cannot be decoded
Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
                                             std::shared_ptr<Tensor> *tensor) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  // Message made consistent with LoadFeatureIndex ("column:" + key).
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  // When the value required a copy, GetColumnValueByName returns it via
  // data_ptr and leaves `data` null. Use .get() instead of &data_ptr[0] to
  // avoid dereferencing a null unique_ptr.
  if (data == nullptr) data = reinterpret_cast<const unsigned char *>(data_ptr.get());
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
                                            std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])),
                                            data, tensor));
  return Status::OK();
}
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
|
||||
GraphSharedMemory *shared_memory,
|
||||
std::shared_ptr<Tensor> *out_tensor) {
|
||||
const unsigned char *data = nullptr;
|
||||
std::unique_ptr<unsigned char[]> data_ptr;
|
||||
uint64_t n_bytes = 0, col_type_size = 1;
|
||||
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
|
||||
std::vector<int64_t> column_shape;
|
||||
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
|
||||
&col_type_size, &column_shape);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
|
||||
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor));
|
||||
auto fea_itr = tensor->begin<int64_t>();
|
||||
int64_t offset = 0;
|
||||
RETURN_IF_NOT_OK(shared_memory->InsertData(data, n_bytes, &offset));
|
||||
*fea_itr = offset;
|
||||
++fea_itr;
|
||||
*fea_itr = n_bytes;
|
||||
*out_tensor = std::move(tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Reads column `key` as a list of feature indices; negative entries are
// treated as padding and skipped.
// @param key - column name
// @param col_blob - raw blob-field bytes for one row
// @param indices - output: the non-negative feature indices, as int32
// @return Status - error if the column is missing or not int32/int64
Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
                                            std::vector<int32_t> *indices) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);

  // Use .get() instead of &data_ptr[0] to avoid dereferencing a null unique_ptr.
  if (data == nullptr) data = reinterpret_cast<const unsigned char *>(data_ptr.get());

  // Counter is uint64_t: n_bytes is uint64_t, so the original `int i` mixed
  // signed/unsigned comparison and could overflow on very large columns.
  for (uint64_t i = 0; i < n_bytes; i += col_type_size) {
    int32_t feature_ind = -1;
    if (col_type == mindrecord::ColumnInt32) {
      feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
    } else if (col_type == mindrecord::ColumnInt64) {
      // NOTE(review): int64 values are narrowed to int32 here; assumes feature
      // indices fit in 32 bits — confirm with the dataset schema.
      feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
    } else {
      RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
    }
    if (feature_ind >= 0) indices->push_back(feature_ind);
  }
  return Status::OK();
}
|
||||
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,67 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/data_type.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/gnn/feature.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/mindrecord/include/shard_column.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
using mindrecord::ShardColumn;
|
||||
|
||||
// Decodes node/edge feature data out of mindrecord blob fields, using an
// owned ShardColumn copy to locate and type each column.
class GraphFeatureParser {
 public:
  explicit GraphFeatureParser(const ShardColumn &shard_column);

  ~GraphFeatureParser() = default;

  // @param std::string key - column name
  // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
  // @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
  // @return Status - the status code
  Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, std::vector<int32_t> *ind);

  // @param std::string &key - column name
  // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
  // @param std::shared_ptr<Tensor> *tensor - return value feature tensor
  // @return Status - the status code
  Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, std::shared_ptr<Tensor> *tensor);
#if !defined(_WIN32) && !defined(_WIN64)
  // Copies the feature bytes into `shared_memory` and returns a 2-element
  // int64 tensor holding {offset, length} of the stored block.
  Status LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
                                   GraphSharedMemory *shared_memory, std::shared_ptr<Tensor> *out_tensor);
#endif
 private:
  std::unique_ptr<ShardColumn> shard_column_;  // owned copy of the column decoder
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,134 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
// Constructs a wrapper for the System V segment identified directly by
// `memory_key`; nothing is created or attached until CreateSharedMemory /
// GetSharedMemory is called.
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, key_t memory_key)
    : memory_size_(memory_size),
      memory_key_(memory_key),
      memory_ptr_(nullptr),
      memory_offset_(0),
      is_new_create_(false) {
  // Cache the hex form of the key for log and error messages.
  std::ostringstream key_stream;
  key_stream << std::hex << memory_key_;
  memory_key_str_ = key_stream.str();
}
|
||||
|
||||
// Constructs a wrapper whose key will be derived later (in CreateSharedMemory)
// from `mr_file` via ftok(); until then memory_key_ stays -1.
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, const std::string &mr_file)
    : mr_file_(mr_file),
      memory_size_(memory_size),
      memory_key_(-1),
      memory_ptr_(nullptr),
      memory_offset_(0),
      is_new_create_(false) {}
|
||||
|
||||
// Removes the segment only if this process created it; attach-only users must
// not tear down memory owned by the creator.
GraphSharedMemory::~GraphSharedMemory() {
  if (is_new_create_) {
    (void)DeleteSharedMemory();
  }
}
|
||||
|
||||
// Creates the shared-memory segment, or attaches to an existing one with the
// same key as a fallback. When no key was supplied, one is derived from the
// mindrecord file path via ftok(). Sets is_new_create_ so the destructor knows
// whether this process owns the segment.
// @return Status - error if the key cannot be derived or the segment cannot
//                  be created/attached
Status GraphSharedMemory::CreateSharedMemory() {
  if (memory_key_ == -1) {
    // ftok to generate unique key
    memory_key_ = ftok(mr_file_.data(), kGnnSharedMemoryId);
    CHECK_FAIL_RETURN_UNEXPECTED(memory_key_ != -1, "Failed to get key of shared memory. file_name:" + mr_file_);
    std::stringstream stream;
    stream << std::hex << memory_key_;
    memory_key_str_ = stream.str();
  }
  // Try exclusive creation first so we can tell whether the segment is new.
  int shmflg = (0666 | IPC_CREAT | IPC_EXCL);
  Status s = SharedMemoryImpl(shmflg);
  if (s.IsOk()) {
    is_new_create_ = true;
    MS_LOG(INFO) << "Create shared memory success, key=0x" << memory_key_str_;
  } else {
    MS_LOG(WARNING) << "Shared memory with the same key may already exist, key=0x" << memory_key_str_;
    shmflg = (0666 | IPC_CREAT);
    s = SharedMemoryImpl(shmflg);
    if (!s.IsOk()) {
      // Fix: original message contained the typo "fao;ed".
      RETURN_STATUS_UNEXPECTED("Create shared memory failed, key=0x" + memory_key_str_);
    }
  }
  return Status::OK();
}
|
||||
|
||||
// Attaches to an already-existing segment without creating one (shmflg = 0).
Status GraphSharedMemory::GetSharedMemory() { return SharedMemoryImpl(0); }
|
||||
|
||||
// Marks the System V segment for removal; the kernel frees it once the last
// attached process detaches.
Status GraphSharedMemory::DeleteSharedMemory() {
  const int shm_id = shmget(memory_key_, 0, 0);
  CHECK_FAIL_RETURN_UNEXPECTED(shm_id != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
  CHECK_FAIL_RETURN_UNEXPECTED(shmctl(shm_id, IPC_RMID, 0) != -1,
                               "Failed to delete shared memory. key=0x" + memory_key_str_);
  return Status::OK();
}
|
||||
|
||||
// Gets (or creates, depending on `shmflg`) the segment for memory_key_ and
// attaches it into this process, storing the mapping in memory_ptr_.
// @param shmflg - flags forwarded to shmget (e.g. 0666 | IPC_CREAT)
// @return Status - error if shmget or shmat fails
Status GraphSharedMemory::SharedMemoryImpl(const int &shmflg) {
  // shmget returns an identifier in shmid
  int shmid = shmget(memory_key_, memory_size_, shmflg);
  CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_);

  // shmat attaches the segment; it returns (void *)-1 on failure.
  // nullptr / reinterpret_cast replace the original literal 0 and C-style cast.
  auto data = shmat(shmid, nullptr, 0);
  CHECK_FAIL_RETURN_UNEXPECTED(data != reinterpret_cast<void *>(-1),
                               "Failed to address shared memory. key=0x" + memory_key_str_);
  memory_ptr_ = reinterpret_cast<uint8_t *>(data);

  return Status::OK();
}
|
||||
|
||||
// Appends `len` bytes to the segment under the internal lock and reports the
// byte offset at which they were stored. Offsets only grow; data is never
// overwritten or freed individually.
// @param data - source buffer, must not be null
// @param len - number of bytes to copy, must be > 0
// @param offset - output: start offset of the inserted data within the segment
// @return Status - error on bad input, exhausted space, or copy failure
Status GraphSharedMemory::InsertData(const uint8_t *data, int64_t len, int64_t *offset) {
  CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(len > 0, "Input len is invalid.");

  // Serialize writers: memory_offset_ is the single allocation cursor.
  std::lock_guard<std::mutex> lck(mutex_);
  CHECK_FAIL_RETURN_UNEXPECTED((memory_size_ - memory_offset_ >= len),
                               "Insufficient shared memory space to insert data.");
  if (EOK != memcpy_s(memory_ptr_ + memory_offset_, memory_size_ - memory_offset_, data, len)) {
    RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory.");
  }
  *offset = memory_offset_;
  memory_offset_ += len;
  return Status::OK();
}
|
||||
|
||||
// Copies `get_data_len` bytes starting at `offset` out of the segment into
// the caller's buffer `data` (capacity `data_len`).
// NOTE(review): no lock is taken here; assumes callers only read ranges that
// were fully inserted before the read — confirm with the calling protocol.
// @param data - destination buffer, must not be null
// @param data_len - capacity of the destination buffer
// @param offset - start offset within the segment, must be >= 0
// @param get_data_len - number of bytes to copy, must be > 0
// @return Status - error on bad input, out-of-range read, or copy failure
Status GraphSharedMemory::GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len) {
  CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(get_data_len > 0, "Input get_data_len is invalid.");
  // Reject negative offsets before they reach the pointer arithmetic below.
  CHECK_FAIL_RETURN_UNEXPECTED(offset >= 0, "Input offset is invalid.");
  CHECK_FAIL_RETURN_UNEXPECTED(data_len >= get_data_len, "Insufficient target address space.");

  CHECK_FAIL_RETURN_UNEXPECTED(memory_size_ >= get_data_len + offset,
                               "get_data_len is too large, beyond the space of shared memory.");
  if (EOK != memcpy_s(data, data_len, memory_ptr_ + offset, get_data_len)) {
    // Fix: the original message wrongly said "insert" on this read path.
    RETURN_STATUS_UNEXPECTED("Failed to get data from shared memory.");
  }
  return Status::OK();
}
|
||||
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,72 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
|
||||
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/shm.h>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
const int kGnnSharedMemoryId = 65;
|
||||
|
||||
class GraphSharedMemory {
|
||||
public:
|
||||
explicit GraphSharedMemory(int64_t memory_size, key_t memory_key);
|
||||
explicit GraphSharedMemory(int64_t memory_size, const std::string &mr_file);
|
||||
|
||||
~GraphSharedMemory();
|
||||
|
||||
// @param uint8_t** shared_memory - shared memory address
|
||||
// @return Status - the status code
|
||||
Status CreateSharedMemory();
|
||||
|
||||
// @param uint8_t** shared_memory - shared memory address
|
||||
// @return Status - the status code
|
||||
Status GetSharedMemory();
|
||||
|
||||
Status DeleteSharedMemory();
|
||||
|
||||
Status InsertData(const uint8_t *data, int64_t len, int64_t *offset);
|
||||
|
||||
Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len);
|
||||
|
||||
key_t memory_key() { return memory_key_; }
|
||||
|
||||
int64_t memory_size() { return memory_size_; }
|
||||
|
||||
private:
|
||||
Status SharedMemoryImpl(const int &shmflg);
|
||||
|
||||
std::string mr_file_;
|
||||
int64_t memory_size_;
|
||||
key_t memory_key_;
|
||||
std::string memory_key_str_;
|
||||
uint8_t *memory_ptr_;
|
||||
int64_t memory_offset_;
|
||||
std::mutex mutex_;
|
||||
bool is_new_create_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
|
@ -0,0 +1,82 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Stores the endpoint only; the server itself is brought up later by Run().
GrpcAsyncServer::GrpcAsyncServer(const std::string &host, int32_t port) : host_(host), port_(port) {}
|
||||
|
||||
// Ensures the server and completion queue are shut down before members die.
GrpcAsyncServer::~GrpcAsyncServer() { Stop(); }
|
||||
|
||||
// Builds and starts the gRPC server on host_:port_. Raises the maximum
// receive size to INT32_MAX and disables SO_REUSEPORT so a second server on
// the same port fails to start instead of silently sharing it.
// @return Status - error if the server fails to start (e.g. port is taken)
Status GrpcAsyncServer::Run() {
  std::string server_address = host_ + ":" + std::to_string(port_);
  grpc::ServerBuilder builder;
  // Default message size for gRPC is 4MB. Increase it to 2g-1
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
  builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);
  int port_tcpip = 0;
  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip);
  RETURN_IF_NOT_OK(RegisterService(&builder));
  cq_ = builder.AddCompletionQueue();
  server_ = builder.BuildAndStart();
  if (server_) {
    MS_LOG(INFO) << "Server listening on " << server_address;
  } else {
    std::string errMsg = "Fail to start server. ";
    // AddListeningPort reports the bound port; a mismatch means binding failed.
    if (port_tcpip != port_) {
      errMsg += "Unable to bind to address " + server_address + ".";
    }
    RETURN_STATUS_UNEXPECTED(errMsg);
  }
  return Status::OK();
}
|
||||
|
||||
// Drives the completion-queue event loop: seeds the queue with pending calls,
// then dispatches each completed tag back to its CallData until the queue is
// shut down or the task is interrupted.
// @return Status - first error from enqueueing/processing, or interruption
Status GrpcAsyncServer::HandleRequest() {
  bool success;
  void *tag;
  // We loop through the grpc queue. Each connection if successful
  // will come back with our own tag which is an instance of CallData
  // and we simply call its functor. But first we need to create these instances
  // and inject them into the grpc queue.
  RETURN_IF_NOT_OK(EnqueueRequest());
  while (cq_->Next(&tag, &success)) {
    RETURN_IF_INTERRUPTED();
    if (success) {
      RETURN_IF_NOT_OK(ProcessRequest(tag));
    } else {
      // An unsuccessful event (e.g. during shutdown); skip the tag.
      MS_LOG(DEBUG) << "cq_->Next failed.";
    }
  }
  return Status::OK();
}
|
||||
|
||||
// Shuts down the server first and the completion queue second; gRPC requires
// this order, and it makes Next() in HandleRequest return false so the event
// loop exits.
void GrpcAsyncServer::Stop() {
  if (server_) {
    server_->Shutdown();
  }
  // Always shutdown the completion queue after the server.
  if (cq_) {
    cq_->Shutdown();
  }
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,59 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "grpcpp/grpcpp.h"
|
||||
#include "grpcpp/impl/codegen/async_unary_call.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \brief Async server base class
|
||||
class GrpcAsyncServer {
|
||||
public:
|
||||
explicit GrpcAsyncServer(const std::string &host, int32_t port);
|
||||
virtual ~GrpcAsyncServer();
|
||||
/// \brief Brings up gRPC server
|
||||
/// \return none
|
||||
Status Run();
|
||||
/// \brief Entry function to handle async server request
|
||||
Status HandleRequest();
|
||||
|
||||
void Stop();
|
||||
|
||||
virtual Status RegisterService(grpc::ServerBuilder *builder) = 0;
|
||||
|
||||
virtual Status EnqueueRequest() = 0;
|
||||
|
||||
virtual Status ProcessRequest(void *tag) = 0;
|
||||
|
||||
protected:
|
||||
int32_t port_;
|
||||
std::string host_;
|
||||
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue