!4498 Gnn data processing supports distributed scenarios

Merge pull request !4498 from heleiwang/gnn_distributed
commit 256dccc651
minddata/dataset/engine/gnn/CMakeLists.txt
@@ -1,9 +1,29 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_library(engine-gnn OBJECT
-    graph.cc
+set(DATASET_ENGINE_GNN_SRC_FILES
+    graph_data_impl.cc
+    graph_data_client.cc
+    graph_data_server.cc
     graph_loader.cc
+    graph_feature_parser.cc
     local_node.cc
     local_edge.cc
     feature.cc
-    )
+)
+
+if (CMAKE_SYSTEM_NAME MATCHES "Windows")
+    add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES})
+else()
+    set(DATASET_ENGINE_GNN_SRC_FILES
+        ${DATASET_ENGINE_GNN_SRC_FILES}
+        tensor_proto.cc
+        grpc_async_server.cc
+        graph_data_service_impl.cc
+        graph_shared_memory.cc)
+
+    ms_protobuf_generate(TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS "gnn_tensor.proto")
+    ms_grpc_generate(GNN_PROTO_SRCS GNN_PROTO_HDRS "gnn_graph_data.proto")
+
+    add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES} ${TENSOR_PROTO_SRCS} ${GNN_PROTO_SRCS})
+    add_dependencies(engine-gnn mindspore::protobuf)
+endif()
minddata/dataset/engine/gnn/gnn_graph_data.proto
@@ -0,0 +1,103 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

syntax = "proto3";

package mindspore.dataset;

import "gnn_tensor.proto";

message GnnClientRegisterRequestPb {
  int32 pid = 1;
}

message GnnFeatureInfoPb {
  int32 type = 1;
  TensorPb feature = 2;
}

message GnnClientRegisterResponsePb {
  string error_msg = 1;
  string data_schema = 2;
  int64 shared_memory_key = 3;
  int64 shared_memory_size = 4;
  repeated GnnFeatureInfoPb default_node_feature = 5;
  repeated GnnFeatureInfoPb default_edge_feature = 6;
}

message GnnClientUnRegisterRequestPb {
  int32 pid = 1;
}

message GnnClientUnRegisterResponsePb {
  string error_msg = 1;
}

enum GnnOpName {
  GET_ALL_NODES = 0;
  GET_ALL_EDGES = 1;
  GET_NODES_FROM_EDGES = 2;
  GET_ALL_NEIGHBORS = 3;
  GET_SAMPLED_NEIGHBORS = 4;
  GET_NEG_SAMPLED_NEIGHBORS = 5;
  RANDOM_WALK = 6;
  GET_NODE_FEATURE = 7;
  GET_EDGE_FEATURE = 8;
}

message GnnRandomWalkPb {
  float p = 1;
  float q = 2;
  int32 default_id = 3;
}

message GnnGraphDataRequestPb {
  GnnOpName op_name = 1;
  repeated int32 id = 2;      // node id or edge id
  repeated int32 type = 3;    // node type, edge type, neighbor type or feature type
  repeated int32 number = 4;  // number of samples
  TensorPb id_tensor = 5;     // input ids, node id or edge id
  GnnRandomWalkPb random_walk = 6;
}

message GnnGraphDataResponsePb {
  string error_msg = 1;
  repeated TensorPb result_data = 2;
}

message GnnMetaInfoRequestPb {
}

message GnnNodeEdgeInfoPb {
  int32 type = 1;
  int32 num = 2;
}

message GnnMetaInfoResponsePb {
  string error_msg = 1;
  repeated GnnNodeEdgeInfoPb node_info = 2;
  repeated GnnNodeEdgeInfoPb edge_info = 3;
  repeated int32 node_feature_type = 4;
  repeated int32 edge_feature_type = 5;
}

service GnnGraphData {
  rpc ClientRegister(GnnClientRegisterRequestPb) returns (GnnClientRegisterResponsePb);
  rpc ClientUnRegister(GnnClientUnRegisterRequestPb) returns (GnnClientUnRegisterResponsePb);
  rpc GetGraphData(GnnGraphDataRequestPb) returns (GnnGraphDataResponsePb);
  rpc GetMetaInfo(GnnMetaInfoRequestPb) returns (GnnMetaInfoResponsePb);
}
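For orientation, here is a hedged sketch (not part of the PR) of how a caller could exercise the GnnGraphData service with the C++ classes generated from this proto. The channel address and sampling parameters are illustrative; the real client added by this PR goes through GraphDataClient and additionally reads feature payloads from shared memory.

// Sketch only: assumes the gRPC/protobuf code generated from gnn_graph_data.proto is available.
#include "gnn_graph_data.grpc.pb.h"

#include <grpcpp/grpcpp.h>
#include <iostream>
#include <memory>

int main() {
  // Build a GET_SAMPLED_NEIGHBORS request: sample 10 then 5 neighbors of type 0 for nodes 1..3.
  mindspore::dataset::GnnGraphDataRequestPb request;
  request.set_op_name(mindspore::dataset::GET_SAMPLED_NEIGHBORS);
  for (int node_id : {1, 2, 3}) request.add_id(node_id);
  request.add_number(10);
  request.add_number(5);
  request.add_type(0);
  request.add_type(0);

  // Call the GnnGraphData service over a local channel (host and port are illustrative).
  auto channel = grpc::CreateChannel("127.0.0.1:50051", grpc::InsecureChannelCredentials());
  auto stub = mindspore::dataset::GnnGraphData::NewStub(channel);
  grpc::ClientContext context;
  mindspore::dataset::GnnGraphDataResponsePb response;
  grpc::Status status = stub->GetGraphData(&context, request, &response);
  if (!status.ok() || !response.error_msg().empty()) {
    std::cerr << "GetGraphData failed: " << response.error_msg() << std::endl;
    return 1;
  }
  // Each result tensor carries shape, dtype and raw bytes as defined in gnn_tensor.proto.
  for (const auto &tensor : response.result_data()) {
    std::cout << "result tensor with " << tensor.dims_size() << " dims" << std::endl;
  }
  return 0;
}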
minddata/dataset/engine/gnn/gnn_tensor.proto
@@ -0,0 +1,42 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

syntax = "proto3";

package mindspore.dataset;

enum DataTypePb {
  DE_PB_UNKNOWN = 0;
  DE_PB_BOOL = 1;
  DE_PB_INT8 = 2;
  DE_PB_UINT8 = 3;
  DE_PB_INT16 = 4;
  DE_PB_UINT16 = 5;
  DE_PB_INT32 = 6;
  DE_PB_UINT32 = 7;
  DE_PB_INT64 = 8;
  DE_PB_UINT64 = 9;
  DE_PB_FLOAT16 = 10;
  DE_PB_FLOAT32 = 11;
  DE_PB_FLOAT64 = 12;
  DE_PB_STRING = 13;
}

message TensorPb {
  repeated int64 dims = 1;     // tensor shape info
  DataTypePb tensor_type = 2;  // tensor content data type
  bytes data = 3;              // tensor data
}
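As an illustration (not part of the change), packing a 2x3 int32 buffer into TensorPb with the generated C++ API could look like the following sketch; the helper name is hypothetical.

// Sketch only: assumes the code generated from gnn_tensor.proto is available.
#include "gnn_tensor.pb.h"

#include <cstdint>
#include <vector>

mindspore::dataset::TensorPb PackInt32Tensor(const std::vector<int32_t> &values) {
  mindspore::dataset::TensorPb tensor;
  tensor.add_dims(2);
  tensor.add_dims(3);  // shape {2, 3}, so values is expected to hold 6 elements
  tensor.set_tensor_type(mindspore::dataset::DE_PB_INT32);
  // The payload is the raw byte image of the buffer; the receiver recovers the element type from tensor_type.
  tensor.set_data(values.data(), values.size() * sizeof(int32_t));
  return tensor;
}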
minddata/dataset/engine/gnn/graph_data.h
@@ -0,0 +1,134 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace gnn {

struct MetaInfo {
  std::vector<NodeType> node_type;
  std::vector<EdgeType> edge_type;
  std::map<NodeType, NodeIdType> node_num;
  std::map<EdgeType, EdgeIdType> edge_num;
  std::vector<FeatureType> node_feature_type;
  std::vector<FeatureType> edge_feature_type;
};

class GraphData {
 public:
  // Get all nodes from the graph.
  // @param NodeType node_type - type of node
  // @param std::shared_ptr<Tensor> *out - Returned node ids
  // @return Status - The error code returned
  virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;

  // Get all edges from the graph.
  // @param EdgeType edge_type - type of edge
  // @param std::shared_ptr<Tensor> *out - Returned edge ids
  // @return Status - The error code returned
  virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;

  // Get the node ids from the edges.
  // @param std::vector<EdgeIdType> edge_list - List of edges
  // @param std::shared_ptr<Tensor> *out - Returned node ids
  // @return Status - The error code returned
  virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;

  // Get all neighbors of the given nodes.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
  // @param std::shared_ptr<Tensor> *out - Returned neighbor ids. Because different nodes have different numbers of
  // neighbors, the output is padded to the maximum neighbor count; missing entries are filled with -1.
  // @return Status - The error code returned
  virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
                                 std::shared_ptr<Tensor> *out) = 0;

  // Get sampled neighbors.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
  // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
  // @param std::shared_ptr<Tensor> *out - Returned neighbor ids
  // @return Status - The error code returned
  virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
                                     const std::vector<NodeIdType> &neighbor_nums,
                                     const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;

  // Get negative sampled neighbors.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param NodeIdType samples_num - Number of neighbors sampled
  // @param NodeType neg_neighbor_type - The type of negative neighbor
  // @param std::shared_ptr<Tensor> *out - Returned negative neighbor ids
  // @return Status - The error code returned
  virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
                                        NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;

  // Node2vec random walk.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param std::vector<NodeType> meta_path - node type of each step
  // @param float step_home_param - return hyper parameter in node2vec algorithm
  // @param float step_away_param - inout hyper parameter in node2vec algorithm
  // @param NodeIdType default_node - default node id
  // @param std::shared_ptr<Tensor> *out - Returned node ids in the walk path
  // @return Status - The error code returned
  virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
                            float step_home_param, float step_away_param, NodeIdType default_node,
                            std::shared_ptr<Tensor> *out) = 0;

  // Get the features of nodes.
  // @param std::shared_ptr<Tensor> nodes - List of nodes
  // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type
  // does not exist.
  // @param TensorRow *out - Returned features
  // @return Status - The error code returned
  virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
                                TensorRow *out) = 0;

  // Get the features of edges.
  // @param std::shared_ptr<Tensor> edges - List of edges
  // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type
  // does not exist.
  // @param TensorRow *out - Returned features
  // @return Status - The error code returned
  virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
                                TensorRow *out) = 0;

  // Return meta information to the python layer
  virtual Status GraphInfo(py::dict *out) = 0;

  virtual Status Init() = 0;

  virtual Status Stop() = 0;
};
}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
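To make the out-parameter/Status convention above concrete, a minimal hedged sketch of a caller fetching node ids and their features through this interface follows. The concrete object behind the pointer would be the local GraphDataImpl or the new GraphDataClient; the node type value is illustrative, and the dataset utility macros (RETURN_IF_NOT_OK, MS_LOG) are assumed to be in scope via the included headers.

// Sketch only; not code from the PR.
#include "minddata/dataset/engine/gnn/graph_data.h"

namespace mindspore {
namespace dataset {
namespace gnn {

Status DumpNodeFeatures(GraphData *graph, NodeType node_type, const std::vector<FeatureType> &feature_types) {
  std::shared_ptr<Tensor> nodes;
  RETURN_IF_NOT_OK(graph->GetAllNodes(node_type, &nodes));  // ids of all nodes of the given type

  TensorRow features;
  RETURN_IF_NOT_OK(graph->GetNodeFeature(nodes, feature_types, &features));  // one tensor per requested feature type

  MS_LOG(INFO) << "Fetched " << features.size() << " feature tensors.";
  return Status::OK();
}

}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore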
												
											
										
									
(File diff suppressed because it is too large)

minddata/dataset/engine/gnn/graph_data_client.h
@@ -0,0 +1,185 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_

#include <algorithm>
#include <memory>
#include <string>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <utility>

#if !defined(_WIN32) && !defined(_WIN64)
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"
#endif
#include "minddata/dataset/engine/gnn/graph_data.h"
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_column.h"

namespace mindspore {
namespace dataset {
namespace gnn {

class GraphDataClient : public GraphData {
 public:
  // Constructor
  // @param std::string dataset_file - dataset file of the graph
  // @param std::string hostname - hostname of the graph data server
  // @param int32_t port - port of the graph data server
  GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port);

  ~GraphDataClient();

  Status Init() override;

  Status Stop() override;

  // Get all nodes from the graph.
  // @param NodeType node_type - type of node
  // @param std::shared_ptr<Tensor> *out - Returned node ids
  // @return Status - The error code returned
  Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;

  // Get all edges from the graph.
  // @param EdgeType edge_type - type of edge
  // @param std::shared_ptr<Tensor> *out - Returned edge ids
  // @return Status - The error code returned
  Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;

  // Get the node ids from the edges.
  // @param std::vector<EdgeIdType> edge_list - List of edges
  // @param std::shared_ptr<Tensor> *out - Returned node ids
  // @return Status - The error code returned
  Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;

  // Get all neighbors of the given nodes.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
  // @param std::shared_ptr<Tensor> *out - Returned neighbor ids. Because different nodes have different numbers of
  // neighbors, the output is padded to the maximum neighbor count; missing entries are filled with -1.
  // @return Status - The error code returned
  Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
                         std::shared_ptr<Tensor> *out) override;

  // Get sampled neighbors.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
  // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
  // @param std::shared_ptr<Tensor> *out - Returned neighbor ids
  // @return Status - The error code returned
  Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
                             const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;

  // Get negative sampled neighbors.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param NodeIdType samples_num - Number of neighbors sampled
  // @param NodeType neg_neighbor_type - The type of negative neighbor
  // @param std::shared_ptr<Tensor> *out - Returned negative neighbor ids
  // @return Status - The error code returned
  Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
                                NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;

  // Node2vec random walk.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param std::vector<NodeType> meta_path - node type of each step
  // @param float step_home_param - return hyper parameter in node2vec algorithm
  // @param float step_away_param - inout hyper parameter in node2vec algorithm
  // @param NodeIdType default_node - default node id
  // @param std::shared_ptr<Tensor> *out - Returned node ids in the walk path
  // @return Status - The error code returned
  Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
                    float step_home_param, float step_away_param, NodeIdType default_node,
                    std::shared_ptr<Tensor> *out) override;

  // Get the features of nodes.
  // @param std::shared_ptr<Tensor> nodes - List of nodes
  // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type
  // does not exist.
  // @param TensorRow *out - Returned features
  // @return Status - The error code returned
  Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
                        TensorRow *out) override;

  // Get the features of edges.
  // @param std::shared_ptr<Tensor> edges - List of edges
  // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type
  // does not exist.
  // @param TensorRow *out - Returned features
  // @return Status - The error code returned
  Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
                        TensorRow *out) override;

  // Return meta information to the python layer
  Status GraphInfo(py::dict *out) override;

 private:
#if !defined(_WIN32) && !defined(_WIN64)
  Status ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type,
                                    const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);

  Status ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type,
                                    const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);

  Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);

  Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);

  Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response);

  Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
                            std::shared_ptr<Tensor> *out);

  Status RegisterToServer();

  Status UnRegisterToServer();

  Status InitFeatureParser();

  Status CheckPid() {
    CHECK_FAIL_RETURN_UNEXPECTED(pid_ == getpid(),
                                 "Multi-process mode is not supported, please change to use multi-thread");
    return Status::OK();
  }
#endif

  std::string dataset_file_;
  std::string host_;
  int32_t port_;
  int32_t pid_;
  mindrecord::json data_schema_;
#if !defined(_WIN32) && !defined(_WIN64)
  std::unique_ptr<GnnGraphData::Stub> stub_;
  key_t shared_memory_key_;
  int64_t shared_memory_size_;
  std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
  std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
  std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_node_feature_map_;
  std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_edge_feature_map_;
#endif
  bool registered_;
};
}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
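A hedged sketch of the client lifecycle this header implies: construct with the dataset file and the server address, Init() to register with the server and set up the feature parser/shared memory, then Stop() when done. The path, host and port below are illustrative, and the exact teardown behaviour of Stop() is an assumption based on the private UnRegisterToServer() above.

// Sketch only; not code from the PR.
#include "minddata/dataset/engine/gnn/graph_data_client.h"

mindspore::dataset::Status RunClientOnce() {
  using mindspore::dataset::Tensor;
  using mindspore::dataset::gnn::GraphDataClient;

  GraphDataClient client("/path/to/graph.mindrecord", "127.0.0.1", 50051);  // illustrative values
  RETURN_IF_NOT_OK(client.Init());  // registers this process (pid) with the GraphDataServer

  std::shared_ptr<Tensor> nodes;
  RETURN_IF_NOT_OK(client.GetAllNodes(1, &nodes));  // node type 1 is illustrative

  return client.Stop();  // assumed to unregister from the server (see UnRegisterToServer above)
}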
												
											
										
									
(File diff suppressed because it is too large)

minddata/dataset/engine/gnn/graph_data_server.cc
@@ -0,0 +1,133 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/gnn/graph_data_server.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <utility>

#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/util/random.h"

namespace mindspore {
namespace dataset {
namespace gnn {

GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname,
                                 int32_t port, int32_t client_num, bool auto_shutdown)
    : dataset_file_(dataset_file),
      num_workers_(num_workers),
      client_num_(client_num),
      max_connected_client_num_(0),
      auto_shutdown_(auto_shutdown),
      state_(kGdsUninit) {
  tg_ = std::make_unique<TaskGroup>();
  graph_data_impl_ = std::make_unique<GraphDataImpl>(dataset_file, num_workers, true);
#if !defined(_WIN32) && !defined(_WIN64)
  service_impl_ = std::make_unique<GraphDataServiceImpl>(this, graph_data_impl_.get());
  async_server_ = std::make_unique<GraphDataGrpcServer>(hostname, port, service_impl_.get());
#endif
}

Status GraphDataServer::Init() {
#if defined(_WIN32) || defined(_WIN64)
  RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS");
#else
  set_state(kGdsInitializing);
  RETURN_IF_NOT_OK(async_server_->Run());
  // RETURN_IF_NOT_OK(InitGraphDataImpl());
  RETURN_IF_NOT_OK(tg_->CreateAsyncTask("init graph data impl", std::bind(&GraphDataServer::InitGraphDataImpl, this)));
  for (int32_t i = 0; i < num_workers_; ++i) {
    RETURN_IF_NOT_OK(
      tg_->CreateAsyncTask("start async rpc service", std::bind(&GraphDataServer::StartAsyncRpcService, this)));
  }
  if (auto_shutdown_) {
    RETURN_IF_NOT_OK(
      tg_->CreateAsyncTask("judge auto shutdown server", std::bind(&GraphDataServer::JudgeAutoShutdownServer, this)));
  }
  return Status::OK();
#endif
}

Status GraphDataServer::InitGraphDataImpl() {
  TaskManager::FindMe()->Post();
  Status s = graph_data_impl_->Init();
  if (s.IsOk()) {
    set_state(kGdsRunning);
  } else {
    (void)Stop();
  }
  return s;
}

#if !defined(_WIN32) && !defined(_WIN64)
Status GraphDataServer::StartAsyncRpcService() {
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(async_server_->HandleRequest());
  return Status::OK();
}
#endif

Status GraphDataServer::JudgeAutoShutdownServer() {
  TaskManager::FindMe()->Post();
  while (true) {
    if (auto_shutdown_ && (max_connected_client_num_ >= client_num_) && (client_pid_.size() == 0)) {
      MS_LOG(INFO) << "All clients have been unregistered, automatically exiting the server.";
      RETURN_IF_NOT_OK(Stop());
      break;
    }
    if (state_ == kGdsStopped) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  }
  return Status::OK();
}

Status GraphDataServer::Stop() {
#if !defined(_WIN32) && !defined(_WIN64)
  async_server_->Stop();
#endif
  set_state(kGdsStopped);
  graph_data_impl_.reset();
  return Status::OK();
}

Status GraphDataServer::ClientRegister(int32_t pid) {
  std::unique_lock<std::mutex> lck(mutex_);
  MS_LOG(INFO) << "client register pid:" << std::to_string(pid);
  client_pid_.emplace(pid);
  if (client_pid_.size() > max_connected_client_num_) {
    max_connected_client_num_ = client_pid_.size();
  }
  return Status::OK();
}
Status GraphDataServer::ClientUnRegister(int32_t pid) {
  std::unique_lock<std::mutex> lck(mutex_);
  auto itr = client_pid_.find(pid);
  if (itr != client_pid_.end()) {
    client_pid_.erase(itr);
    MS_LOG(INFO) << "client unregister pid:" << std::to_string(pid);
  }
  return Status::OK();
}

}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore
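For orientation, here is a hedged sketch of bringing the server up with the constructor and Init() above. In MindSpore this is driven from the Python layer through bindings, so the standalone main, the dataset path and the port are purely illustrative. Note that Init() only spawns the loading and rpc worker tasks, so the caller waits until the auto-shutdown logic (JudgeAutoShutdownServer) moves the state to kGdsStopped.

// Sketch only; not code from the PR.
#include <chrono>
#include <thread>

#include "minddata/dataset/engine/gnn/graph_data_server.h"

int main() {
  // Serve one client and shut down automatically once it has registered and then unregistered.
  mindspore::dataset::gnn::GraphDataServer server("/path/to/graph.mindrecord", /*num_workers=*/4, "127.0.0.1",
                                                  /*port=*/50051, /*client_num=*/1, /*auto_shutdown=*/true);
  if (!server.Init().IsOk()) {
    return 1;
  }
  while (!server.IsStoped()) {  // IsStoped() is the spelling used by the header
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
  return 0;
}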
minddata/dataset/engine/gnn/graph_data_server.h
@@ -0,0 +1,196 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_

#include <memory>
#include <mutex>
#include <string>
#include <unordered_set>

#if !defined(_WIN32) && !defined(_WIN64)
#include "grpcpp/grpcpp.h"
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
#endif
#include "minddata/dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
namespace gnn {

class GraphDataImpl;

class GraphDataServer {
 public:
  enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped };
  GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
                  int32_t client_num, bool auto_shutdown);
  ~GraphDataServer() = default;

  Status Init();

  Status Stop();

  Status ClientRegister(int32_t pid);
  Status ClientUnRegister(int32_t pid);

  enum ServerState state() { return state_; }

  bool IsStoped() {
    if (state_ == kGdsStopped) {
      return true;
    } else {
      return false;
    }
  }

 private:
  void set_state(enum ServerState state) { state_ = state; }

  Status InitGraphDataImpl();
#if !defined(_WIN32) && !defined(_WIN64)
  Status StartAsyncRpcService();
#endif
  Status JudgeAutoShutdownServer();

  std::string dataset_file_;
  int32_t num_workers_;  // The number of worker threads
  int32_t client_num_;
  int32_t max_connected_client_num_;
  bool auto_shutdown_;
  enum ServerState state_;
  std::unique_ptr<TaskGroup> tg_;  // Class for worker management
  std::unique_ptr<GraphDataImpl> graph_data_impl_;
  std::unordered_set<int32_t> client_pid_;
  std::mutex mutex_;
#if !defined(_WIN32) && !defined(_WIN64)
  std::unique_ptr<GraphDataServiceImpl> service_impl_;
  std::unique_ptr<GrpcAsyncServer> async_server_;
#endif
};

#if !defined(_WIN32) && !defined(_WIN64)
class UntypedCall {
 public:
  virtual ~UntypedCall() {}

  virtual Status operator()() = 0;
};

template <class ServiceImpl, class AsyncService, class RequestMessage, class ResponseMessage>
class CallData : public UntypedCall {
 public:
  enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
  using EnqueueFunction = void (AsyncService::*)(grpc::ServerContext *, RequestMessage *,
                                                 grpc::ServerAsyncResponseWriter<ResponseMessage> *,
                                                 grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
  using HandleRequestFunction = grpc::Status (ServiceImpl::*)(grpc::ServerContext *, const RequestMessage *,
                                                              ResponseMessage *);
  CallData(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
           EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function)
      : status_(STATE::CREATE),
        service_impl_(service_impl),
        async_service_(async_service),
        cq_(cq),
        enqueue_function_(enqueue_function),
        handle_request_function_(handle_request_function),
        responder_(&ctx_) {}

  ~CallData() = default;

  static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
                               EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) {
    auto call = new CallData<ServiceImpl, AsyncService, RequestMessage, ResponseMessage>(
      service_impl, async_service, cq, enqueue_function, handle_request_function);
    RETURN_IF_NOT_OK((*call)());
    return Status::OK();
  }

  Status operator()() {
    if (status_ == STATE::CREATE) {
      status_ = STATE::PROCESS;
      (async_service_->*enqueue_function_)(&ctx_, &request_, &responder_, cq_, cq_, this);
    } else if (status_ == STATE::PROCESS) {
      EnqueueRequest(service_impl_, async_service_, cq_, enqueue_function_, handle_request_function_);
      status_ = STATE::FINISH;
      // new CallData(service_, cq_, this->s_type_);
      grpc::Status s = (service_impl_->*handle_request_function_)(&ctx_, &request_, &response_);
      responder_.Finish(response_, s, this);
    } else {
      GPR_ASSERT(status_ == STATE::FINISH);
      delete this;
    }
    return Status::OK();
  }

 private:
  STATE status_;
  ServiceImpl *service_impl_;
  AsyncService *async_service_;
  grpc::ServerCompletionQueue *cq_;
  EnqueueFunction enqueue_function_;
  HandleRequestFunction handle_request_function_;
  grpc::ServerContext ctx_;
  grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
  RequestMessage request_;
  ResponseMessage response_;
};

#define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg)                       \
  do {                                                                                                            \
    Status s =                                                                                                    \
      CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, request_msg, response_msg>::EnqueueRequest( \
        service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method,                            \
        &gnn::GraphDataServiceImpl::method);                                                                      \
    RETURN_IF_NOT_OK(s);                                                                                          \
  } while (0)

class GraphDataGrpcServer : public GrpcAsyncServer {
 public:
  GraphDataGrpcServer(const std::string &host, int32_t port, GraphDataServiceImpl *service_impl)
      : GrpcAsyncServer(host, port), service_impl_(service_impl) {}

  Status RegisterService(grpc::ServerBuilder *builder) {
    builder->RegisterService(&svc_);
    return Status::OK();
  }

  Status EnqueueRequest() {
    ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientRegister, GnnClientRegisterRequestPb,
                    GnnClientRegisterResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientUnRegister, GnnClientUnRegisterRequestPb,
                    GnnClientUnRegisterResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetGraphData, GnnGraphDataRequestPb, GnnGraphDataResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb);
    return Status::OK();
  }

  Status ProcessRequest(void *tag) {
    auto rq = static_cast<UntypedCall *>(tag);
    RETURN_IF_NOT_OK((*rq)());
    return Status::OK();
  }

 private:
  GraphDataServiceImpl *service_impl_;
  GnnGraphData::AsyncService svc_;
};
#endif
}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
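The ENQUEUE_REQUEST macro above is purely mechanical: it instantiates CallData for one rpc and posts the initial request onto the completion queue. Expanded by hand for ClientRegister, the first call in EnqueueRequest() is roughly the following; each completed tag then drives the CREATE -> PROCESS -> FINISH state machine in CallData::operator()().

// Hand-expansion of:
//   ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientRegister,
//                   GnnClientRegisterRequestPb, GnnClientRegisterResponsePb);
do {
  Status s = CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, GnnClientRegisterRequestPb,
                      GnnClientRegisterResponsePb>::EnqueueRequest(service_impl_, &svc_, cq_.get(),
                                                                   &GnnGraphData::AsyncService::RequestClientRegister,
                                                                   &gnn::GraphDataServiceImpl::ClientRegister);
  RETURN_IF_NOT_OK(s);
} while (0);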
												
											
										
									
(File diff suppressed because it is too large)

minddata/dataset/engine/gnn/graph_data_service_impl.h
@@ -0,0 +1,70 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_

#include <memory>
#include <string>

#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"

namespace mindspore {
namespace dataset {
namespace gnn {

class GraphDataServer;

// class GraphDataServiceImpl : public GnnGraphData::Service {
class GraphDataServiceImpl {
 public:
  GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl);
  ~GraphDataServiceImpl() = default;

  grpc::Status ClientRegister(grpc::ServerContext *context, const GnnClientRegisterRequestPb *request,
                              GnnClientRegisterResponsePb *response);

  grpc::Status ClientUnRegister(grpc::ServerContext *context, const GnnClientUnRegisterRequestPb *request,
                                GnnClientUnRegisterResponsePb *response);

  grpc::Status GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request,
                            GnnGraphDataResponsePb *response);

  grpc::Status GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request,
                           GnnMetaInfoResponsePb *response);

  Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
  Status GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);

 private:
  Status FillDefaultFeature(GnnClientRegisterResponsePb *response);

  GraphDataServer *server_;
  GraphDataImpl *graph_data_impl_;
};

}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
@ -0,0 +1,106 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
 | 
				
			||||
 | 
				
			||||
#include <memory>
 | 
				
			||||
#include <utility>
 | 
				
			||||
 | 
				
			||||
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace dataset {
 | 
				
			||||
namespace gnn {
 | 
				
			||||
 | 
				
			||||
using mindrecord::MSRStatus;
 | 
				
			||||
 | 
				
			||||
GraphFeatureParser::GraphFeatureParser(const ShardColumn &shard_column) {
 | 
				
			||||
  shard_column_ = std::make_unique<ShardColumn>(shard_column);
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
 | 
				
			||||
                                             std::shared_ptr<Tensor> *tensor) {
 | 
				
			||||
  const unsigned char *data = nullptr;
 | 
				
			||||
  std::unique_ptr<unsigned char[]> data_ptr;
 | 
				
			||||
  uint64_t n_bytes = 0, col_type_size = 1;
 | 
				
			||||
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
 | 
				
			||||
  std::vector<int64_t> column_shape;
 | 
				
			||||
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
 | 
				
			||||
                                                     &col_type_size, &column_shape);
 | 
				
			||||
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
 | 
				
			||||
  if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
 | 
				
			||||
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
 | 
				
			||||
                                            std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])),
 | 
				
			||||
                                            data, tensor));
 | 
				
			||||
  return Status::OK();
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
#if !defined(_WIN32) && !defined(_WIN64)
Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
                                                     GraphSharedMemory *shared_memory,
                                                     std::shared_ptr<Tensor> *out_tensor) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "Failed to load column: " + key);
  if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
  // The returned tensor does not hold the feature itself; it holds [offset, length] of the
  // feature bytes inside the shared memory segment.
  std::shared_ptr<Tensor> tensor;
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor));
  auto fea_itr = tensor->begin<int64_t>();
  int64_t offset = 0;
  RETURN_IF_NOT_OK(shared_memory->InsertData(data, n_bytes, &offset));
  *fea_itr = offset;
  ++fea_itr;
  *fea_itr = n_bytes;
  *out_tensor = std::move(tensor);
  return Status::OK();
}
#endif

Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
                                            std::vector<int32_t> *indices) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "Failed to load column: " + key);

  if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);

  for (uint64_t i = 0; i < n_bytes; i += col_type_size) {
    int32_t feature_ind = -1;
    if (col_type == mindrecord::ColumnInt32) {
      feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
    } else if (col_type == mindrecord::ColumnInt64) {
      feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
    } else {
      RETURN_STATUS_UNEXPECTED("Feature index needs to be int32/int64 type!");
    }
    if (feature_ind >= 0) indices->push_back(feature_ind);
  }
  return Status::OK();
}

}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore
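For context, a minimal consumer-side sketch of how the two-element tensor returned by LoadFeatureToSharedMemory (element 0 = offset, element 1 = byte length) can be read back out of the segment; this is not part of the patch, and it assumes an already-attached GraphSharedMemory named shared_memory and the returned tensor named meta.

// Hypothetical consumer-side sketch: read back the feature bytes described by the
// [offset, length] tensor that LoadFeatureToSharedMemory produced.
auto itr = meta->begin<int64_t>();
int64_t offset = *itr;
int64_t length = *(++itr);
std::vector<uint8_t> buffer(length);
RETURN_IF_NOT_OK(shared_memory->GetData(buffer.data(), static_cast<int64_t>(buffer.size()), offset, length));
// buffer now holds the raw feature bytes copied out of the shared memory segment.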
@ -0,0 +1,67 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_

#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_column.h"

namespace mindspore {
namespace dataset {
namespace gnn {

using mindrecord::ShardColumn;

class GraphFeatureParser {
 public:
  explicit GraphFeatureParser(const ShardColumn &shard_column);

  ~GraphFeatureParser() = default;

  // @param std::string key - column name
  // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
  // @param std::vector<int32_t> *ind - return value, list of feature indices in int32_t
  // @return Status - the status code
  Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, std::vector<int32_t> *ind);

  // @param std::string &key - column name
  // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
  // @param std::shared_ptr<Tensor> *tensor - return value, feature tensor
  // @return Status - the status code
  Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, std::shared_ptr<Tensor> *tensor);
#if !defined(_WIN32) && !defined(_WIN64)
  // @param std::string &key - column name
  // @param std::vector<uint8_t> &col_blob - contains data in blob field in mindrecord
  // @param GraphSharedMemory *shared_memory - shared memory segment that receives the feature bytes
  // @param std::shared_ptr<Tensor> *out_tensor - return value, [offset, length] of the feature in shared memory
  // @return Status - the status code
  Status LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
                                   GraphSharedMemory *shared_memory, std::shared_ptr<Tensor> *out_tensor);
#endif

 private:
  std::unique_ptr<ShardColumn> shard_column_;
};
}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
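A minimal usage sketch of the API above, assuming a mindrecord::ShardColumn built from the dataset schema and a single row's blob bytes; the column names are placeholders for illustration, not the ones the patch actually uses.

// Hypothetical usage sketch of GraphFeatureParser (column names are placeholders).
GraphFeatureParser parser(shard_column);

// List the feature types stored for a node row.
std::vector<int32_t> feature_types;
RETURN_IF_NOT_OK(parser.LoadFeatureIndex("node_feature_index", blob, &feature_types));

// Materialize one feature column as a Tensor.
std::shared_ptr<Tensor> feature;
RETURN_IF_NOT_OK(parser.LoadFeatureTensor("node_feature_1", blob, &feature));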
File diff suppressed because it is too large
@ -0,0 +1,134 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/gnn/graph_shared_memory.h"

#include <sstream>
#include <string>

#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
namespace gnn {

GraphSharedMemory::GraphSharedMemory(int64_t memory_size, key_t memory_key)
    : memory_size_(memory_size),
      memory_key_(memory_key),
      memory_ptr_(nullptr),
      memory_offset_(0),
      is_new_create_(false) {
  std::stringstream stream;
  stream << std::hex << memory_key_;
  memory_key_str_ = stream.str();
}

GraphSharedMemory::GraphSharedMemory(int64_t memory_size, const std::string &mr_file)
    : mr_file_(mr_file),
      memory_size_(memory_size),
      memory_key_(-1),
      memory_ptr_(nullptr),
      memory_offset_(0),
      is_new_create_(false) {}

GraphSharedMemory::~GraphSharedMemory() {
  // Only the process that created the segment removes it.
  if (is_new_create_) {
    (void)DeleteSharedMemory();
  }
}

Status GraphSharedMemory::CreateSharedMemory() {
  if (memory_key_ == -1) {
    // ftok generates a unique key from the mindrecord file path.
    memory_key_ = ftok(mr_file_.data(), kGnnSharedMemoryId);
    CHECK_FAIL_RETURN_UNEXPECTED(memory_key_ != -1, "Failed to get key of shared memory. file_name: " + mr_file_);
    std::stringstream stream;
    stream << std::hex << memory_key_;
    memory_key_str_ = stream.str();
  }
  // First try to create a brand-new segment; fall back to attaching an existing one.
  int shmflg = (0666 | IPC_CREAT | IPC_EXCL);
  Status s = SharedMemoryImpl(shmflg);
  if (s.IsOk()) {
    is_new_create_ = true;
    MS_LOG(INFO) << "Create shared memory success, key=0x" << memory_key_str_;
  } else {
    MS_LOG(WARNING) << "Shared memory with the same key may already exist, key=0x" << memory_key_str_;
    shmflg = (0666 | IPC_CREAT);
    s = SharedMemoryImpl(shmflg);
    if (!s.IsOk()) {
      RETURN_STATUS_UNEXPECTED("Failed to create shared memory, key=0x" + memory_key_str_);
    }
  }
  return Status::OK();
}

Status GraphSharedMemory::GetSharedMemory() {
  int shmflg = 0;
  RETURN_IF_NOT_OK(SharedMemoryImpl(shmflg));
  return Status::OK();
}

Status GraphSharedMemory::DeleteSharedMemory() {
  int shmid = shmget(memory_key_, 0, 0);
  CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
  int result = shmctl(shmid, IPC_RMID, nullptr);
  CHECK_FAIL_RETURN_UNEXPECTED(result != -1, "Failed to delete shared memory. key=0x" + memory_key_str_);
  return Status::OK();
}

Status GraphSharedMemory::SharedMemoryImpl(const int &shmflg) {
  // shmget returns an identifier in shmid
  int shmid = shmget(memory_key_, memory_size_, shmflg);
  CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_);

  // shmat to attach to shared memory
  auto data = shmat(shmid, reinterpret_cast<void *>(0), 0);
  CHECK_FAIL_RETURN_UNEXPECTED(data != reinterpret_cast<void *>(-1),
                               "Failed to attach shared memory. key=0x" + memory_key_str_);
  memory_ptr_ = reinterpret_cast<uint8_t *>(data);

  return Status::OK();
}

Status GraphSharedMemory::InsertData(const uint8_t *data, int64_t len, int64_t *offset) {
  CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(len > 0, "Input len is invalid.");

  std::lock_guard<std::mutex> lck(mutex_);
  CHECK_FAIL_RETURN_UNEXPECTED((memory_size_ - memory_offset_ >= len),
                               "Insufficient shared memory space to insert data.");
  if (EOK != memcpy_s(memory_ptr_ + memory_offset_, memory_size_ - memory_offset_, data, len)) {
    RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory.");
  }
  // Report where the data landed and advance the append offset.
  *offset = memory_offset_;
  memory_offset_ += len;
  return Status::OK();
}

Status GraphSharedMemory::GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len) {
  CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(get_data_len > 0, "Input get_data_len is invalid.");
  CHECK_FAIL_RETURN_UNEXPECTED(data_len >= get_data_len, "Insufficient target address space.");

  CHECK_FAIL_RETURN_UNEXPECTED(memory_size_ >= get_data_len + offset,
                               "get_data_len is too large, beyond the space of shared memory.");
  if (EOK != memcpy_s(data, data_len, memory_ptr_ + offset, get_data_len)) {
    RETURN_STATUS_UNEXPECTED("Failed to get data from shared memory.");
  }
  return Status::OK();
}

}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore
@ -0,0 +1,72 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_

#include <sys/ipc.h>
#include <sys/shm.h>
#include <mutex>
#include <string>

#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace gnn {

const int kGnnSharedMemoryId = 65;

class GraphSharedMemory {
 public:
  // Construct from an existing shared memory key
  explicit GraphSharedMemory(int64_t memory_size, key_t memory_key);
  // Construct from a mindrecord file path; the key is derived with ftok on first use
  explicit GraphSharedMemory(int64_t memory_size, const std::string &mr_file);

  ~GraphSharedMemory();

  // Create the shared memory segment (or attach to an existing one with the same key)
  // @return Status - the status code
  Status CreateSharedMemory();

  // Attach to an existing shared memory segment
  // @return Status - the status code
  Status GetSharedMemory();

  Status DeleteSharedMemory();

  Status InsertData(const uint8_t *data, int64_t len, int64_t *offset);

  Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len);

  key_t memory_key() { return memory_key_; }

  int64_t memory_size() { return memory_size_; }

 private:
  Status SharedMemoryImpl(const int &shmflg);

  std::string mr_file_;
  int64_t memory_size_;
  key_t memory_key_;
  std::string memory_key_str_;
  uint8_t *memory_ptr_;
  int64_t memory_offset_;
  std::mutex mutex_;
  bool is_new_create_;
};
}  // namespace gnn
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
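A rough lifecycle sketch, under the assumption that the loading side creates the segment from the mindrecord file path while a consumer attaches using the key and size it was handed (for example through the client registration response); the sizes and variable names here are illustrative only.

// Hypothetical create/attach sketch (memory_size, mr_file_path, feature_bytes, feature_len are placeholders).
// Creating side: derive the key from the mindrecord file and create the segment.
GraphSharedMemory creator(memory_size, mr_file_path);
RETURN_IF_NOT_OK(creator.CreateSharedMemory());
int64_t offset = 0;
RETURN_IF_NOT_OK(creator.InsertData(feature_bytes, feature_len, &offset));

// Attaching side: reuse the published key and size, then read the bytes back.
GraphSharedMemory consumer(memory_size, creator.memory_key());
RETURN_IF_NOT_OK(consumer.GetSharedMemory());
std::vector<uint8_t> buf(feature_len);
RETURN_IF_NOT_OK(consumer.GetData(buf.data(), feature_len, offset, feature_len));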
@ -0,0 +1,82 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/gnn/grpc_async_server.h"

#include <limits>

#include "minddata/dataset/util/task_manager.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {

GrpcAsyncServer::GrpcAsyncServer(const std::string &host, int32_t port) : host_(host), port_(port) {}

GrpcAsyncServer::~GrpcAsyncServer() { Stop(); }

Status GrpcAsyncServer::Run() {
  std::string server_address = host_ + ":" + std::to_string(port_);
  grpc::ServerBuilder builder;
  // The default maximum message size for gRPC is 4 MB. Raise it to INT32_MAX (2 GB - 1).
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
  builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);
  int port_tcpip = 0;
  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip);
  RETURN_IF_NOT_OK(RegisterService(&builder));
  cq_ = builder.AddCompletionQueue();
  server_ = builder.BuildAndStart();
  if (server_) {
    MS_LOG(INFO) << "Server listening on " << server_address;
  } else {
    std::string errMsg = "Failed to start server. ";
    if (port_tcpip != port_) {
      errMsg += "Unable to bind to address " + server_address + ".";
    }
    RETURN_STATUS_UNEXPECTED(errMsg);
  }
  return Status::OK();
}

Status GrpcAsyncServer::HandleRequest() {
  bool success = false;
  void *tag = nullptr;
  // We loop through the gRPC completion queue. Each successful event comes back with the tag we
  // injected, which is an instance of CallData, and we simply call its functor. But first we need
  // to create these instances and inject them into the queue.
  RETURN_IF_NOT_OK(EnqueueRequest());
  while (cq_->Next(&tag, &success)) {
    RETURN_IF_INTERRUPTED();
    if (success) {
      RETURN_IF_NOT_OK(ProcessRequest(tag));
    } else {
      MS_LOG(DEBUG) << "cq_->Next failed.";
    }
  }
  return Status::OK();
}

void GrpcAsyncServer::Stop() {
  if (server_) {
    server_->Shutdown();
  }
  // Always shut down the completion queue after the server.
  if (cq_) {
    cq_->Shutdown();
  }
}
}  // namespace dataset
}  // namespace mindspore
@ -0,0 +1,59 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "grpcpp/grpcpp.h"
#include "grpcpp/impl/codegen/async_unary_call.h"

#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

/// \brief Async server base class
class GrpcAsyncServer {
 public:
  explicit GrpcAsyncServer(const std::string &host, int32_t port);
  virtual ~GrpcAsyncServer();
  /// \brief Brings up the gRPC server
  /// \return Status object
  Status Run();
  /// \brief Entry function to handle async server requests
  Status HandleRequest();

  void Stop();

  /// \brief Registers the concrete service with the server builder
  virtual Status RegisterService(grpc::ServerBuilder *builder) = 0;

  /// \brief Creates the pending call handlers and injects them into the completion queue
  virtual Status EnqueueRequest() = 0;

  /// \brief Handles a completed event identified by its tag
  virtual Status ProcessRequest(void *tag) = 0;

 protected:
  int32_t port_;
  std::string host_;
  std::unique_ptr<grpc::ServerCompletionQueue> cq_;
  std::unique_ptr<grpc::Server> server_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
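A minimal sketch of what a concrete subclass of GrpcAsyncServer might look like; the service pointer and the empty handler bodies are placeholders rather than the actual graph data service wiring in this patch.

// Hypothetical subclass sketch (placeholder service, no real call handling).
class ExampleAsyncServer : public GrpcAsyncServer {
 public:
  ExampleAsyncServer(const std::string &host, int32_t port, grpc::Service *svc)
      : GrpcAsyncServer(host, port), svc_(svc) {}

  // Register the generated async service with the server builder before BuildAndStart().
  Status RegisterService(grpc::ServerBuilder *builder) override {
    builder->RegisterService(svc_);
    return Status::OK();
  }

  // Create the per-call handlers and hand them to gRPC as completion-queue tags.
  Status EnqueueRequest() override { return Status::OK(); }

  // Recover the handler from the tag and advance it (read request, send reply, re-enqueue).
  Status ProcessRequest(void *tag) override { return Status::OK(); }

 private:
  grpc::Service *svc_;  // assumed to outlive the server
};

A driver would then call Run() once and loop in HandleRequest() on a dedicated task until Stop() is invoked.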
Some files were not shown because too many files have changed in this diff