Custom data transmission format

pull/11477/head
chendongsheng 4 years ago
parent 14a6713d08
commit c7fe82b43d

File diff suppressed because it is too large Load Diff

@ -25,6 +25,7 @@
#include <unordered_map> #include <unordered_map>
#include "ps/core/node.h" #include "ps/core/node.h"
#include "ps/core/message.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
@ -34,53 +35,63 @@ class AbstractNode : public Node {
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
~AbstractNode() override = default; ~AbstractNode() override = default;
typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message); typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
bool Broadcast(const enum NodeRole &node_role, const CommMessage &message, using DataPtr = std::shared_ptr<unsigned char>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
void set_event_callback(const OnNodeEventMessage &on_node_event_message); void set_event_callback(const OnNodeEventMessage &on_node_event_message);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output, bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
const std::vector<size_t> &lens, int command, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command,
VectorPtr *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message); uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
CommMessage *output); void **output, size_t *size);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds); bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
protected: protected:
void Register(const std::shared_ptr<TcpClient> &client); void Register(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(const CommMessage &message);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false); bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void UpdateSchedulerTime(); void UpdateSchedulerTime();
bool CheckSchedulerTimeout() const; bool CheckSchedulerTimeout() const;
void ProcessHeartbeatResp(const CommMessage &message);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessFetchServersResp(const CommMessage &message);
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
bool WaitForDisconnect(const uint32_t &timeout); bool WaitForDisconnect(const uint32_t &timeout);
bool InitClientToScheduler(); bool InitClientToScheduler();
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message); bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
void ProcessSendDataResp(const CommMessage &message); const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size);
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void RunMessageCallback(const uint64_t &request_id); void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback); void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
void NotifyMessageArrival(const CommMessage &message); void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback); void RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void RunReceiveCallback(const CommMessage &message);
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id); uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
uint64_t NextActualRankRequestId(const uint32_t &rank_id); uint64_t NextActualRankRequestId(const uint32_t &rank_id);
void InitCommandHandler(); void InitCommandHandler();
uint64_t AddMessageTrack(const uint32_t &expected_response);
bool CheckMessageTrack(const uint64_t &request_id);
std::unique_ptr<std::thread> heart_beat_thread_; std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_; std::unique_ptr<std::thread> client_to_scheduler_thread_;
@ -98,15 +109,16 @@ class AbstractNode : public Node {
std::mutex message_tracker_mutex_; std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_; std::condition_variable message_tracker_cond_;
// the key is: request_id, the value is:<rank_id, CommMessage> // the key is: request_id, the value is: <rank_id, RecvMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_; std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_;
std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_;
std::mutex receive_messages_mutex_; std::mutex receive_messages_mutex_;
// the key is: request_id // the key is: request_id
std::unordered_map<uint64_t, MessageCallback> message_callbacks_; std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
std::mutex message_callbacks_mutex_; std::mutex message_callbacks_mutex_;
// the key is <rank_id, rank_request_id> // the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, CommMessage> received_data_; std::map<std::pair<uint32_t, uint64_t>, std::shared_ptr<std::vector<unsigned char>>> received_data_;
std::mutex receive_callbacks_mutex_; std::mutex receive_callbacks_mutex_;
// the key is <rank_id, rank_request_id> // the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_; std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_;

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_
#define MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_
#include <cstdint>
#include <memory>
#include <string>
namespace mindspore {
namespace ps {
namespace core {
// How the payload that follows the message meta is encoded. The underlying type is fixed to uint32_t so that
// the enum occupies exactly four bytes inside MessageHeader.
enum class Protos : uint32_t { RAW = 0, PROTOBUF = 1, FLATBUFFERS = 2 };

// The command carried by a message.
// NOTE(review): presumably mirrors the NodeCommand enum of the protobuf schema - confirm the numeric values
// stay in sync with it.
enum class Command {
  TERMINATE = 0,
  REGISTER = 1,
  HEARTBEAT = 2,
  SEND_DATA = 3,
  FETCH_SERVER = 4,
  FINISH = 5,
  COLLECTIVE_SEND_DATA = 6
};

// The role a node plays in the cluster.
enum class Role { SERVER = 0, WORKER = 1, SCHEDULER = 2 };

// Fixed-size header describing one message: the payload encoding and two lengths. Every field has a default,
// so a default-constructed header is fully initialized.
// NOTE(review): TcpClient::SendMessage appears to write this struct to the socket byte-for-byte, so the field
// order and field types are part of the wire format - confirm before changing either.
struct MessageHeader {
  // the encoding of the payload
  Protos message_proto_ = Protos::RAW;
  // the byte size of the serialized message meta
  uint32_t message_meta_length_ = 0;
  // the byte size of the message; NOTE(review): for raw sends this looks like meta length plus data length
  uint64_t message_length_ = 0;
};

// Plain C++ description of a message's meta information. All members carry a zero default, so a
// default-constructed CommandMeta never exposes indeterminate values.
struct CommandMeta {
  // the command of this message,for example: register,heartbeat,data
  Command cmd = Command::TERMINATE;
  // the request id of this message
  uint64_t request_id = 0;
  // the role of the current node: worker,server,scheduler
  Role role = Role::SERVER;
  // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
  // Defaults to 0: the former default of 4 matched the protobuf field tag (`int32 rank_id = 4;`) and was not a
  // meaningful rank.
  int32_t rank_id = 0;
};
}  // namespace core
}  // namespace ps
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_

@ -15,7 +15,6 @@
*/ */
syntax = "proto3"; syntax = "proto3";
import "google/protobuf/any.proto";
package mindspore.ps.core; package mindspore.ps.core;
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;
@ -44,6 +43,8 @@ message MessageMeta {
NodeRole role = 3; NodeRole role = 3;
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
int32 rank_id = 4; int32 rank_id = 4;
// User-defined commands
int32 user_cmd = 5;
} }
message RegisterMessage { message RegisterMessage {
@ -76,6 +77,10 @@ message HeartbeatRespMessage {
bool is_node_timeout = 4; bool is_node_timeout = 4;
} }
message FetchServersMessage {
string node_id = 1;
}
message FetchServersRespMessage { message FetchServersRespMessage {
repeated ServersMeta servers_meta = 1; repeated ServersMeta servers_meta = 1;
} }
@ -95,6 +100,4 @@ message FinishMessage {
message CommMessage { message CommMessage {
MessageMeta pb_meta = 1; MessageMeta pb_meta = 1;
bytes data = 2; bytes data = 2;
// User-defined commands
bytes user_cmd = 3;
} }

@ -38,9 +38,13 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
} }
void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) { std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
HeartbeatMessage heartbeat_message; HeartbeatMessage heartbeat_message;
heartbeat_message.ParseFromString(message->data()); heartbeat_message.ParseFromArray(data, size);
node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
@ -60,10 +64,8 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
*comm_message->mutable_pb_meta() = {message->pb_meta()}; heartbeat_resp_message.ByteSizeLong());
comm_message->set_data(heartbeat_resp_message.SerializeAsString());
server->SendMessage(conn, comm_message);
} }
void SchedulerNode::Initialize() { void SchedulerNode::Initialize() {
@ -89,12 +91,13 @@ void SchedulerNode::CreateTcpServer() {
std::string scheduler_host = ClusterConfig::scheduler_host(); std::string scheduler_host = ClusterConfig::scheduler_host();
uint32_t scheduler_port = ClusterConfig::scheduler_port(); uint32_t scheduler_port = ClusterConfig::scheduler_port();
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port); server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
if (handlers_.count(message->pb_meta().cmd()) == 0) { const Protos &protos, const void *data, size_t size) {
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; if (handlers_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
} }
const auto &handler_ptr = handlers_[message->pb_meta().cmd()]; const auto &handler_ptr = handlers_[meta->cmd()];
(this->*handler_ptr)(server_, conn, message); (this->*handler_ptr)(server_, conn, meta, data, size);
}); });
server_->Init(); server_->Init();
@ -106,10 +109,14 @@ void SchedulerNode::CreateTcpServer() {
} }
void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) { std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
MS_LOG(INFO) << "The scheduler process a register message!"; MS_LOG(INFO) << "The scheduler process a register message!";
RegisterMessage register_message; RegisterMessage register_message;
register_message.ParseFromString(message->data()); register_message.ParseFromArray(data, size);
// assign worker node and server node rank id // assign worker node and server node rank id
int rank_id = node_manager_.NextRankId(register_message); int rank_id = node_manager_.NextRankId(register_message);
@ -123,32 +130,32 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
register_resp_message.set_node_id(node_id); register_resp_message.set_node_id(node_id);
register_resp_message.set_rank_id(rank_id); register_resp_message.set_rank_id(rank_id);
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
*comm_message->mutable_pb_meta() = {message->pb_meta()}; register_resp_message.ByteSizeLong());
comm_message->set_data(register_resp_message.SerializeAsString());
server->SendMessage(conn, comm_message);
} }
void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) { std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
FinishMessage finish_message; FinishMessage finish_message;
finish_message.ParseFromString(message->data()); finish_message.ParseFromArray(data, size);
node_manager_.AddFinishNode(finish_message); node_manager_.AddFinishNode(finish_message);
MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id();
server->SendMessage(conn, message); server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
} }
void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) { std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
FetchServersRespMessage fetch_servers_message; FetchServersRespMessage fetch_servers_message;
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta(); std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
*fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); server->SendMessage(conn, meta, Protos::PROTOBUF, fetch_servers_message.SerializeAsString().data(),
*comm_message->mutable_pb_meta() = {message->pb_meta()}; fetch_servers_message.ByteSizeLong());
comm_message->set_data(fetch_servers_message.SerializeAsString());
server->SendMessage(conn, comm_message);
} }
void SchedulerNode::StartUpdateClusterStateTimer() { void SchedulerNode::StartUpdateClusterStateTimer() {

@ -36,13 +36,14 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
class SchedulerNode : public Node { class SchedulerNode : public Node {
public: public:
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
~SchedulerNode() override; ~SchedulerNode() override;
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message); std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override; bool Stop() override;
@ -53,14 +54,14 @@ class SchedulerNode : public Node {
void InitCommandHandler(); void InitCommandHandler();
void CreateTcpServer(); void CreateTcpServer();
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message); std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message); std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void StartUpdateClusterStateTimer(); void StartUpdateClusterStateTimer();
void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message); std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message); std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
std::shared_ptr<TcpServer> server_; std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_; std::unique_ptr<std::thread> scheduler_thread_;

@ -46,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) {
void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data,
size_t size) {
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message); MS_EXCEPTION_IF_NULL(meta);
message->mutable_pb_meta()->set_role(node_info_.node_role_); MS_EXCEPTION_IF_NULL(data);
message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_); meta->set_role(node_info_.node_role_);
const MessageMeta &message_meta = message->pb_meta(); meta->set_rank_id(node_info_.rank_id_);
const uint64_t request_id = message_meta.request_id();
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id();
server_->SendMessage(conn, message); server_->SendMessage(conn, meta, Protos::RAW, data.get(), size);
} }
void ServerNode::CreateTcpServer() { void ServerNode::CreateTcpServer() {
@ -63,17 +63,18 @@ void ServerNode::CreateTcpServer() {
std::string server_ip; std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0); server_ = std::make_shared<TcpServer>(server_ip, 0);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
switch (message->pb_meta().cmd()) { const Protos &protos, const void *data, size_t size) {
switch (meta->cmd()) {
case NodeCommand::SEND_DATA: case NodeCommand::SEND_DATA:
ProcessSendData(conn, message); ProcessSendData(conn, meta, protos, data, size);
break; break;
case NodeCommand::COLLECTIVE_SEND_DATA: case NodeCommand::COLLECTIVE_SEND_DATA:
ProcessCollectiveSendData(conn, message); ProcessCollectiveSendData(conn, meta, data, size);
RunReceiveCallback(*message); RunReceiveCallback(meta, protos, data, size);
break; break;
default: default:
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
} }
}); });
server_->Init(); server_->Init();
@ -99,18 +100,24 @@ void ServerNode::Initialize() {
MS_LOG(INFO) << "Server node init client successful!"; MS_LOG(INFO) << "Server node init client successful!";
} }
void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message); MS_EXCEPTION_IF_NULL(meta);
request_handler_(conn, message); MS_EXCEPTION_IF_NULL(data);
std::shared_ptr<unsigned char> res(new unsigned char[size]);
int ret = memcpy_s(res.get(), size, data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
request_handler_(conn, meta, res, size);
} }
void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message); MS_EXCEPTION_IF_NULL(meta);
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>(); server_->SendMessage(conn, meta, Protos::RAW, data, size);
*comm_message->mutable_pb_meta() = {message->pb_meta()};
server_->SendMessage(conn, comm_message);
} }
bool ServerNode::Stop() { bool ServerNode::Stop() {

@ -23,6 +23,7 @@
#include <string> #include <string>
#include <thread> #include <thread>
#include <utility> #include <utility>
#include <vector>
#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h" #include "ps/core/tcp_client.h"
@ -41,16 +42,19 @@ class ServerNode : public AbstractNode {
bool Stop() override; bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>; using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
DataPtr data, size_t size)>;
void set_handler(const RequestHandler &handler); void set_handler(const RequestHandler &handler);
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data, size_t size);
private: private:
void CreateTcpServer(); void CreateTcpServer();
void Initialize(); void Initialize();
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); const void *data, size_t size);
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const void *data, size_t size);
std::shared_ptr<TcpServer> server_; std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_; std::unique_ptr<std::thread> server_thread_;

@ -46,11 +46,12 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
server_port_(port), server_port_(port),
is_stop_(true), is_stop_(true),
is_connected_(false) { is_connected_(false) {
message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) { message_handler_.SetCallback(
if (message_callback_) { [this](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
message_callback_(*this, *message); if (message_callback_) {
} message_callback_(meta, protos, data, size);
}); }
});
} }
TcpClient::~TcpClient() { TcpClient::~TcpClient() {
@ -189,7 +190,7 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
void TcpClient::OnReadHandler(const void *buf, size_t num) { void TcpClient::OnReadHandler(const void *buf, size_t num) {
MS_EXCEPTION_IF_NULL(buf); MS_EXCEPTION_IF_NULL(buf);
if (read_callback_) { if (read_callback_) {
read_callback_(*this, buf, num); read_callback_(buf, num);
} }
message_handler_.ReceiveMessage(buf, num); message_handler_.ReceiveMessage(buf, num);
} }
@ -198,7 +199,7 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg); auto tcp_client = reinterpret_cast<TcpClient *>(arg);
if (tcp_client->on_timer_callback_) { if (tcp_client->on_timer_callback_) {
tcp_client->on_timer_callback_(*tcp_client); tcp_client->on_timer_callback_();
} }
} }
@ -245,7 +246,7 @@ void TcpClient::Start() {
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
<< "Event base dispatch failed with no events pending or active!"; << "Event base dispatch failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
} }
void TcpClient::StartWithNoBlock() { void TcpClient::StartWithNoBlock() {
@ -256,7 +257,7 @@ void TcpClient::StartWithNoBlock() {
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
} }
void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
@ -265,14 +266,49 @@ bool TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_); MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_lock(buffer_event_); bufferevent_lock(buffer_event_);
bool res = true; bool res = true;
size_t buf_size = message.ByteSizeLong(); size_t buf_size = IntToUint(message.ByteSizeLong());
std::vector<unsigned char> serialized(buf_size); uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong());
message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); Messageheader header;
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { header.message_proto_ = Protos::PROTOBUF;
header.message_length_ = buf_size;
header.message_meta_length_ = meta_size;
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
bufferevent_unlock(buffer_event_);
return res;
}
bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(buffer_event_);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
bufferevent_lock(buffer_event_);
bool res = true;
Messageheader header;
header.message_proto_ = protos;
header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
header.message_length_ = size + header.message_meta_length_;
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!"; MS_LOG(ERROR) << "Event buffer add header failed!";
res = false; res = false;
} }
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
if (bufferevent_write(buffer_event_, data, size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false; res = false;
} }

@ -42,10 +42,10 @@ class TcpClient {
public: public:
using OnConnected = std::function<void()>; using OnConnected = std::function<void()>;
using OnDisconnected = std::function<void()>; using OnDisconnected = std::function<void()>;
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>; using OnRead = std::function<void(const void *, size_t)>;
using OnTimeout = std::function<void(const TcpClient &)>; using OnTimeout = std::function<void()>;
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>; using OnMessage = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
using OnTimer = std::function<void(const TcpClient &)>; using OnTimer = std::function<void()>;
explicit TcpClient(const std::string &address, std::uint16_t port); explicit TcpClient(const std::string &address, std::uint16_t port);
virtual ~TcpClient(); virtual ~TcpClient();
@ -61,6 +61,7 @@ class TcpClient {
void StartWithNoBlock(); void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb); void SetMessageCallback(const OnMessage &cb);
bool SendMessage(const CommMessage &message) const; bool SendMessage(const CommMessage &message) const;
// Sends one framed message through this client's bufferevent: a message header (carrying `protos`, the
// serialized size of `meta`, and the combined meta + payload length), then the serialized `meta`, then
// `size` bytes read from `data`.
// NOTE(review): presumably returns false when any bufferevent_write fails, like the CommMessage overload;
// the end of the definition is not visible in this view - confirm.
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void StartTimer(const uint32_t &time); void StartTimer(const uint32_t &time);
void set_timer_callback(const OnTimer &timer); void set_timer_callback(const OnTimer &timer);
const event_base &eventbase(); const event_base &eventbase();

@ -35,8 +35,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
header_[++header_index_] = *(buffer_data + i); header_[++header_index_] = *(buffer_data + i);
--num; --num;
if (header_index_ == kHeaderLen - 1) { if (header_index_ == kHeaderLen - 1) {
message_length_ = *reinterpret_cast<const size_t *>(header_); message_header_.message_proto_ = *reinterpret_cast<const Protos *>(header_);
remaining_length_ = message_length_; message_header_.message_meta_length_ =
*reinterpret_cast<const uint32_t *>(header_ + sizeof(message_header_.message_proto_));
message_header_.message_length_ = *reinterpret_cast<const size_t *>(
header_ + sizeof(message_header_.message_proto_) + sizeof(message_header_.message_meta_length_));
remaining_length_ = message_header_.message_length_;
message_buffer_.reset(new unsigned char[remaining_length_]); message_buffer_.reset(new unsigned char[remaining_length_]);
buffer_data += (i + 1); buffer_data += (i + 1);
break; break;
@ -57,10 +61,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
} }
if (remaining_length_ == 0) { if (remaining_length_ == 0) {
std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>();
pb_message->ParseFromArray(message_buffer_.get(), message_length_);
if (message_callback_) { if (message_callback_) {
message_callback_(pb_message); std::shared_ptr<MessageMeta> pb_message = std::make_shared<MessageMeta>();
pb_message->ParseFromArray(message_buffer_.get(), message_header_.message_meta_length_);
message_callback_(pb_message, message_header_.message_proto_,
message_buffer_.get() + message_header_.message_meta_length_,
message_header_.message_length_ - message_header_.message_meta_length_);
} }
message_buffer_.reset(); message_buffer_.reset();
message_buffer_ = nullptr; message_buffer_ = nullptr;

@ -24,24 +24,20 @@
#include <vector> #include <vector>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "ps/core/message.h"
#include "proto/comm.pb.h" #include "proto/comm.pb.h"
#include "proto/ps.pb.h" #include "proto/ps.pb.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>; using messageReceive = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
constexpr int kHeaderLen = 8; constexpr int kHeaderLen = 16;
class TcpMessageHandler { class TcpMessageHandler {
public: public:
TcpMessageHandler() TcpMessageHandler()
: is_parsed_(false), : is_parsed_(false), message_buffer_(nullptr), remaining_length_(0), header_index_(-1), last_copy_len_(0) {}
message_buffer_(nullptr),
message_length_(0),
remaining_length_(0),
header_index_(-1),
last_copy_len_(0) {}
virtual ~TcpMessageHandler() = default; virtual ~TcpMessageHandler() = default;
void SetCallback(const messageReceive &cb); void SetCallback(const messageReceive &cb);
@ -51,11 +47,12 @@ class TcpMessageHandler {
messageReceive message_callback_; messageReceive message_callback_;
bool is_parsed_; bool is_parsed_;
std::unique_ptr<unsigned char> message_buffer_; std::unique_ptr<unsigned char> message_buffer_;
size_t message_length_;
size_t remaining_length_; size_t remaining_length_;
char header_[8]; char header_[16];
int header_index_; int header_index_;
size_t last_copy_len_; size_t last_copy_len_;
MessageHeader message_header_;
std::string mBuffer;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps

@ -54,13 +54,39 @@ bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
bufferevent_lock(buffer_event_); bufferevent_lock(buffer_event_);
bool res = true; bool res = true;
size_t buf_size = message->ByteSizeLong(); size_t buf_size = message->ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message->SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!"; MS_LOG(ERROR) << "Event buffer add header failed!";
res = false; res = false;
} }
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
bufferevent_unlock(buffer_event_);
return res;
}
bool TcpConnection::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
bufferevent_lock(buffer_event_);
bool res = true;
Messageheader header;
header.message_proto_ = protos;
header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
header.message_length_ = size + header.message_meta_length_;
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
if (bufferevent_write(buffer_event_, data, size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false; res = false;
} }
@ -158,7 +184,7 @@ void TcpServer::Start() {
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
<< "Event base dispatch failed with no events pending or active!"; << "Event base dispatch failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
} }
void TcpServer::StartWithNoBlock() { void TcpServer::StartWithNoBlock() {
@ -169,7 +195,7 @@ void TcpServer::StartWithNoBlock() {
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
} }
void TcpServer::StartTimerOnlyOnce(const uint32_t &time) { void TcpServer::StartTimerOnlyOnce(const uint32_t &time) {
@ -260,10 +286,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
server->AddConnection(fd, conn); server->AddConnection(fd, conn);
conn->InitConnection([=](std::shared_ptr<CommMessage> message) { conn->InitConnection([=](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
OnServerReceiveMessage on_server_receive = server->GetServerReceive(); OnServerReceiveMessage on_server_receive = server->GetServerReceive();
if (on_server_receive) { if (on_server_receive) {
on_server_receive(conn, message); on_server_receive(conn, meta, protos, data, size);
} }
}); });
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
@ -274,6 +300,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
} }
std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
MS_EXCEPTION_IF_NULL(bev);
std::shared_ptr<TcpConnection> conn = nullptr; std::shared_ptr<TcpConnection> conn = nullptr;
if (client_accept_) { if (client_accept_) {
conn = (client_accept_(*this)); conn = (client_accept_(*this));
@ -367,9 +394,17 @@ bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr
return conn->SendMessage(message); return conn->SendMessage(message);
} }
// Sends `meta` together with `size` bytes of `data`, tagged with `protos`, to a single connection.
// All work is delegated to TcpConnection::SendMessage; its result is returned unchanged.
bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
                            const Protos &protos, const void *data, size_t size) {
  // Guard every pointer argument with MS_EXCEPTION_IF_NULL before it is used (same order as before).
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  const bool sent = conn->SendMessage(meta, protos, data, size);
  return sent;
}
void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) { void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_EXCEPTION_IF_NULL(message); MS_EXCEPTION_IF_NULL(message);
std::lock_guard<std::mutex> lock(connection_mutex_);
for (auto it = connections_.begin(); it != connections_.end(); ++it) { for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(it->second, message); SendMessage(it->second, message);

@ -36,7 +36,6 @@
#include "ps/core/tcp_message_handler.h" #include "ps/core/tcp_message_handler.h"
#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h" #include "utils/convert_utils_base.h"
namespace mindspore { namespace mindspore {
@ -55,6 +54,7 @@ class TcpConnection {
virtual void InitConnection(const messageReceive &callback); virtual void InitConnection(const messageReceive &callback);
virtual void SendMessage(const void *buffer, size_t num) const; virtual void SendMessage(const void *buffer, size_t num) const;
bool SendMessage(std::shared_ptr<CommMessage> message) const; bool SendMessage(std::shared_ptr<CommMessage> message) const;
// Writes one framed message to this connection's bufferevent while holding the bufferevent lock: a message
// header (carrying `protos`, the serialized size of `meta`, and the combined meta + payload length), then
// the serialized `meta`, then `size` bytes read from `data`.
// NOTE(review): presumably returns false when any bufferevent_write fails, like the CommMessage overload;
// the end of the definition is not visible in this view - confirm.
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes); virtual void OnReadHandler(const void *buffer, size_t numBytes);
TcpServer *GetServer() const; TcpServer *GetServer() const;
const evutil_socket_t &GetFd() const; const evutil_socket_t &GetFd() const;
@ -69,7 +69,8 @@ class TcpConnection {
}; };
using OnServerReceiveMessage = using OnServerReceiveMessage =
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>; std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size)>;
class TcpServer { class TcpServer {
public: public:
@ -100,6 +101,8 @@ class TcpServer {
OnServerReceiveMessage GetServerReceive() const; OnServerReceiveMessage GetServerReceive() const;
void SetMessageCallback(const OnServerReceiveMessage &cb); void SetMessageCallback(const OnServerReceiveMessage &cb);
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
// Sends `meta` plus `size` bytes of `data`, tagged with `protos`, to the given connection.
// Returns whatever TcpConnection::SendMessage reports for that connection.
// The last parameter is spelled `size` so the declaration matches the definition and the sibling overloads.
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
                 const void *data, size_t size);
void SendMessage(std::shared_ptr<CommMessage> message); void SendMessage(std::shared_ptr<CommMessage> message);
uint16_t BoundPort() const; uint16_t BoundPort() const;
std::string BoundIp() const; std::string BoundIp() const;

@ -30,7 +30,12 @@ class TestTcpClient : public UT::Common {
TEST_F(TestTcpClient, InitClientIPError) { TEST_F(TestTcpClient, InitClientIPError) {
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000); auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000);
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;
message.ParseFromArray(data, size);
client->SendMessage(message);
});
ASSERT_THROW(client->Init(), std::exception); ASSERT_THROW(client->Init(), std::exception);
} }
@ -38,10 +43,15 @@ TEST_F(TestTcpClient, InitClientIPError) {
TEST_F(TestTcpClient, InitClientPortErrorNoException) { TEST_F(TestTcpClient, InitClientPortErrorNoException) {
auto client = std::make_unique<TcpClient>("127.0.0.1", -1); auto client = std::make_unique<TcpClient>("127.0.0.1", -1);
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;
message.ParseFromArray(data, size);
client->SendMessage(message);
});
EXPECT_NO_THROW(client->Init()); EXPECT_NO_THROW(client->Init());
} }
} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

@ -33,130 +33,144 @@ class TestTcpMessageHandler : public UT::Common {
void TearDown() override {} void TearDown() override {}
}; };
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); }); handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 1000);
});
std::string data(1000, 'a'); std::string data(1000, 'a');
CommMessage message;
message.set_data(data); char result[1018];
size_t buf_size = message.ByteSizeLong();
char result[1011]; MessageMeta meta;
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);
MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }
std::vector<char> serialized(buf_size); memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
handler.ReceiveMessage(result, buf_size + kHeaderLen); handler.ReceiveMessage(result, 1018);
} }
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data_16Header_2meta_1000Data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); }); handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 1000);
});
std::string data(1000, 'a'); std::string data(1000, 'a');
CommMessage message;
message.set_data(data); char result[2036];
size_t buf_size = message.ByteSizeLong();
char result[2022] = {0}; MessageMeta meta;
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); meta.set_request_id(1);
if (ret != 0) { EXPECT_EQ(meta.ByteSizeLong(), 2);
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} MessageHeader header;
std::vector<char> serialized(buf_size); header.message_proto_ = Protos::RAW;
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); header.message_meta_length_ = meta.ByteSizeLong();
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); header.message_length_ = data.length() + meta.ByteSizeLong();
if (ret != 0) { int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }
handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2); memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen);
memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() + data.length(), meta.ByteSizeLong(),
meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() * 2 + data.length(), data.length(), data.data(),
data.length());
handler.ReceiveMessage(result, 2036);
} }
TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { TEST_F(TestTcpMessageHandler, 16header_2meta_4070data_8header_8header_2meta_4070data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); }); handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 4070);
});
std::string data(4070, 'a');
std::string data(4081, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[4096] = {0}; char result[4096] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4); MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);
MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), 8, &header, 8);
handler.ReceiveMessage(result, 4096); handler.ReceiveMessage(result, 4096);
auto temp = reinterpret_cast<char *>(&buf_size); auto temp = reinterpret_cast<char *>(&header);
ret = memcpy_s(result, 4, temp + 4, 4); memcpy_s(result, 8, temp + 8, 8);
if (ret != 0) { memcpy_s(result + 8, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; memcpy_s(result + 8 + 2, data.length(), data.data(), data.length());
}
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4088); handler.ReceiveMessage(result, 4080);
} }
TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { TEST_F(TestTcpMessageHandler, 16Header_2meta_4062Data_16Header_2meta_4062_data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); }); handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 4062);
});
std::string data(4062, 'a');
std::string data(4077, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[4096] = {0}; char result[4096] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen); MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);
MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen);
handler.ReceiveMessage(result, 4096); handler.ReceiveMessage(result, 4096);
ret = memcpy_s(result, buf_size, serialized.data(), buf_size); memcpy_s(result, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
if (ret != 0) { memcpy_s(result + meta.ByteSizeLong(), data.length(), data.data(), data.length());
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4080); handler.ReceiveMessage(result, 4064);
} }
} // namespace core } // namespace core
} // namespace ps } // namespace ps

@ -33,11 +33,12 @@ class TestTcpServer : public UT::Common {
server_ = std::make_unique<TcpServer>("127.0.0.1", 0); server_ = std::make_unique<TcpServer>("127.0.0.1", 0);
std::unique_ptr<std::thread> http_server_thread_(nullptr); std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([=]() { http_server_thread_ = std::make_unique<std::thread>([=]() {
server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) { server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
KVMessage kv_message; KVMessage kv_message;
kv_message.ParseFromString(message->data()); kv_message.ParseFromArray(data, size);
EXPECT_EQ(2, kv_message.keys_size()); EXPECT_EQ(2, kv_message.keys_size());
server_->SendMessage(conn, message); server_->SendMessage(conn, meta, protos, data, size);
}); });
server_->Init(); server_->Init();
server_->Start(); server_->Start();
@ -61,23 +62,24 @@ TEST_F(TestTcpServer, ServerSendMessage) {
std::cout << server_->BoundPort() << std::endl; std::cout << server_->BoundPort() << std::endl;
std::unique_ptr<std::thread> http_client_thread(nullptr); std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() { http_client_thread = std::make_unique<std::thread>([&]() {
client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client_->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
KVMessage kv_message; KVMessage message;
kv_message.ParseFromString(message.data()); message.ParseFromArray(data, size);
EXPECT_EQ(2, kv_message.keys_size()); EXPECT_EQ(2, message.keys_size());
}); });
client_->Init(); client_->Init();
CommMessage comm_message;
KVMessage kv_message; KVMessage kv_message;
std::vector<int> keys{1, 2}; std::vector<int> keys{1, 2};
std::vector<int> values{3, 4}; std::vector<int> values{3, 4};
*kv_message.mutable_keys() = {keys.begin(), keys.end()}; *kv_message.mutable_keys() = {keys.begin(), keys.end()};
*kv_message.mutable_values() = {values.begin(), values.end()}; *kv_message.mutable_values() = {values.begin(), values.end()};
comm_message.set_data(kv_message.SerializeAsString()); auto message_meta = std::make_shared<MessageMeta>();
client_->SendMessage(comm_message); message_meta->set_cmd(NodeCommand::SEND_DATA);
client_->SendMessage(message_meta, Protos::RAW, kv_message.SerializeAsString().data(), kv_message.ByteSizeLong());
client_->Start(); client_->Start();
}); });

Loading…
Cancel
Save