!7911 added key value message

Merge pull request !7911 from anancds/kv-patch
pull/7911/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 7123a2c1d1

@ -100,6 +100,11 @@ message("onnx proto path is :" ${ONNX_PROTO})
ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO}) ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO})
list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS}) list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS})
include_directories("${CMAKE_BINARY_DIR}/ps/comm")
file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/comm/protos/*.proto")
ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN})
list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS})
if (ENABLE_DEBUGGER) if (ENABLE_DEBUGGER)
# debugger: compile proto files # debugger: compile proto files
include_directories("${CMAKE_BINARY_DIR}/debug/debugger") include_directories("${CMAKE_BINARY_DIR}/debug/debugger")
@ -290,7 +295,7 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive)
else () else ()
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(mindspore mindspore::pslite mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
if (${ENABLE_IBVERBS} STREQUAL "ON") if (${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(mindspore ibverbs rdmacm) target_link_libraries(mindspore ibverbs rdmacm)
endif() endif()

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
import "google/protobuf/any.proto";
package mindspore.ps;
option optimize_for = LITE_RUNTIME;
message MessageMeta {
// hostname or ip
string hostname = 1;
// the port of this node
int32 port = 2;
// the command of this message,for example: registerheartbeatdata
int32 cmd = 3;
// the timestamp of this message
int32 timestamp = 4;
// data type of message
repeated int32 data_type = 5 [packed = true];
// message.data_size
int32 data_size = 6;
}
message CommMessage {
MessageMeta pb_meta = 1;
bytes data = 2;
}

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
message KVMessage {
repeated int32 keys = 1;
repeated float values = 2;
}
message HeartBeatMessage {
// *.*.*.*:port
repeated string host_and_port = 1;
}

@ -36,18 +36,16 @@ namespace mindspore {
namespace ps { namespace ps {
namespace comm { namespace comm {
TcpClient::TcpClient(std::string address, std::uint16_t port) TcpClient::TcpClient(const std::string &address, std::uint16_t port)
: event_base_(nullptr), : event_base_(nullptr),
event_timeout_(nullptr), event_timeout_(nullptr),
buffer_event_(nullptr), buffer_event_(nullptr),
server_address_(std::move(address)), server_address_(std::move(address)),
server_port_(port) { server_port_(port) {
message_handler_.SetCallback([this](const void *buf, size_t num) { message_handler_.SetCallback([this](const CommMessage &message) {
if (buf == nullptr) { if (message_callback_) {
if (disconnected_callback_) disconnected_callback_(*this, 200); message_callback_(*this, message);
Stop();
} }
if (message_callback_) message_callback_(*this, buf, num);
}); });
} }
@ -63,7 +61,7 @@ void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disco
timeout_callback_ = timeout; timeout_callback_ = timeout;
} }
void TcpClient::InitTcpClient() { void TcpClient::Init() {
if (buffer_event_) { if (buffer_event_) {
return; return;
} }
@ -139,7 +137,7 @@ void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) { void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg); auto tcp_client = reinterpret_cast<TcpClient *>(arg);
tcp_client->InitTcpClient(); tcp_client->Init();
} }
void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
@ -150,10 +148,10 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
char read_buffer[4096]; char read_buffer[4096];
int read = 0;
while ((read = EVBUFFER_LENGTH(input)) > 0) { while (EVBUFFER_LENGTH(input) > 0) {
if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) { int read = evbuffer_remove(input, &read_buffer, sizeof(read_buffer));
if (read == -1) {
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
} }
tcp_client->OnReadHandler(read_buffer, read); tcp_client->OnReadHandler(read_buffer, read);
@ -196,25 +194,38 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void
void TcpClient::Start() { void TcpClient::Start() {
MS_EXCEPTION_IF_NULL(event_base_); MS_EXCEPTION_IF_NULL(event_base_);
int ret = event_base_dispatch(event_base_); int ret = event_base_dispatch(event_base_);
if (ret == 0) { MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
MS_LOG(INFO) << "Event base dispatch success!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
} else if (ret == 1) { << "Event base dispatch failed with no events pending or active!";
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
} else if (ret == -1) { MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; }
} else {
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; void TcpClient::StartWithNoBlock() {
} MS_LOG(INFO) << "Start tcp client with no block!";
MS_EXCEPTION_IF_NULL(event_base_);
int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
} }
void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; } void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
void TcpClient::SendMessage(const void *buf, size_t num) const { void TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_); MS_EXCEPTION_IF_NULL(buffer_event_);
if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) { uint32_t buf_size = message.ByteSizeLong();
MS_LOG(EXCEPTION) << "Event buffer add failed!"; std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
}
if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
} }
} }
} // namespace comm } // namespace comm
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

@ -23,6 +23,10 @@
#include <event2/bufferevent.h> #include <event2/bufferevent.h>
#include <functional> #include <functional>
#include <string> #include <string>
#include <memory>
#include <vector>
#include "proto/comm.pb.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
@ -30,24 +34,25 @@ namespace comm {
class TcpClient { class TcpClient {
public: public:
using OnMessage = std::function<void(const TcpClient &, const void *, size_t)>;
using OnConnected = std::function<void(const TcpClient &)>; using OnConnected = std::function<void(const TcpClient &)>;
using OnDisconnected = std::function<void(const TcpClient &, int)>; using OnDisconnected = std::function<void(const TcpClient &, int)>;
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>; using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
using OnTimeout = std::function<void(const TcpClient &)>; using OnTimeout = std::function<void(const TcpClient &)>;
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
explicit TcpClient(std::string address, std::uint16_t port); explicit TcpClient(const std::string &address, std::uint16_t port);
virtual ~TcpClient(); virtual ~TcpClient();
std::string GetServerAddress() const; std::string GetServerAddress() const;
void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
const OnTimeout &timeout); const OnTimeout &timeout);
void InitTcpClient(); void Init();
void StartWithDelay(int seconds); void StartWithDelay(int seconds);
void Stop(); void Stop();
void ReceiveMessage(const OnMessage &cb);
void SendMessage(const void *buf, size_t num) const;
void Start(); void Start();
void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb);
void SendMessage(const CommMessage &message) const;
protected: protected:
static void SetTcpNoDelay(const evutil_socket_t &fd); static void SetTcpNoDelay(const evutil_socket_t &fd);
@ -57,8 +62,9 @@ class TcpClient {
virtual void OnReadHandler(const void *buf, size_t num); virtual void OnReadHandler(const void *buf, size_t num);
private: private:
TcpMessageHandler message_handler_;
OnMessage message_callback_; OnMessage message_callback_;
TcpMessageHandler message_handler_;
OnConnected connected_callback_; OnConnected connected_callback_;
OnDisconnected disconnected_callback_; OnDisconnected disconnected_callback_;
OnRead read_callback_; OnRead read_callback_;
@ -71,6 +77,7 @@ class TcpClient {
std::string server_address_; std::string server_address_;
std::uint16_t server_port_; std::uint16_t server_port_;
}; };
} // namespace comm } // namespace comm
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

@ -15,6 +15,8 @@
*/ */
#include "ps/comm/tcp_message_handler.h" #include "ps/comm/tcp_message_handler.h"
#include <arpa/inet.h>
#include <iostream> #include <iostream>
#include <utility> #include <utility>
@ -22,15 +24,55 @@ namespace mindspore {
namespace ps { namespace ps {
namespace comm { namespace comm {
void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); } void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; }
void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
MS_EXCEPTION_IF_NULL(buffer); MS_EXCEPTION_IF_NULL(buffer);
auto buffer_data = reinterpret_cast<const unsigned char *>(buffer);
while (num > 0) {
if (remaining_length_ == 0) {
for (int i = 0; i < 4 && num > 0; ++i) {
header_[++header_index_] = *(buffer_data + i);
--num;
if (header_index_ == 3) {
message_length_ = *reinterpret_cast<const uint32_t *>(header_);
message_length_ = ntohl(message_length_);
remaining_length_ = message_length_;
message_buffer_.reset(new unsigned char[remaining_length_]);
buffer_data += i;
break;
}
}
}
if (remaining_length_ > 0) {
uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
remaining_length_ -= copy_len;
num -= copy_len;
if (message_callback_) { int ret = memcpy_s(message_buffer_.get() + last_copy_len_, copy_len, buffer_data, copy_len);
message_callback_(buffer, num); last_copy_len_ += copy_len;
buffer_data += copy_len;
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
if (remaining_length_ == 0) {
CommMessage pb_message;
pb_message.ParseFromArray(reinterpret_cast<const void *>(message_buffer_.get()), message_length_);
if (message_callback_) {
message_callback_(pb_message);
}
message_buffer_.reset();
message_buffer_ = nullptr;
header_index_ = 0;
last_copy_len_ = 0;
}
}
} }
} }
} // namespace comm } // namespace comm
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

@ -19,26 +19,43 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <string>
#include <memory> #include <memory>
#include <vector>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace comm { namespace comm {
using messageReceive = std::function<void(const void *buffer, size_t len)>; using messageReceive = std::function<void(const CommMessage &message)>;
class TcpMessageHandler { class TcpMessageHandler {
public: public:
TcpMessageHandler() = default; TcpMessageHandler()
: is_parsed_(false),
message_buffer_(nullptr),
message_length_(0),
remaining_length_(0),
header_index_(-1),
last_copy_len_(0) {}
virtual ~TcpMessageHandler() = default; virtual ~TcpMessageHandler() = default;
void SetCallback(messageReceive cb); void SetCallback(const messageReceive &cb);
void ReceiveMessage(const void *buffer, size_t num); void ReceiveMessage(const void *buffer, size_t num);
private: private:
messageReceive message_callback_; messageReceive message_callback_;
bool is_parsed_;
std::unique_ptr<unsigned char> message_buffer_;
size_t message_length_;
uint32_t remaining_length_;
char header_[4];
int header_index_;
uint32_t last_copy_len_;
}; };
} // namespace comm } // namespace comm
} // namespace ps } // namespace ps

@ -33,16 +33,12 @@ namespace mindspore {
namespace ps { namespace ps {
namespace comm { namespace comm {
void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) { void TcpConnection::InitConnection() {
MS_EXCEPTION_IF_NULL(bev); tcp_message_handler_.SetCallback([&](const CommMessage &message) {
MS_EXCEPTION_IF_NULL(server); OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
buffer_event_ = const_cast<struct bufferevent *>(bev); if (on_server_receive) {
fd_ = fd; on_server_receive(*server_, *this, message);
server_ = const_cast<TcpServer *>(server); }
tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) {
OnServerReceiveMessage message_callback = server->GetServerReceiveMessage();
if (message_callback) message_callback(*server, *this, buf, num);
}); });
} }
@ -54,11 +50,26 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const {
} }
} }
TcpServer *TcpConnection::GetServer() const { return server_; } TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); }
evutil_socket_t TcpConnection::GetFd() const { return fd_; } const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
TcpServer::TcpServer(std::string address, std::uint16_t port) void TcpConnection::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
uint32_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,
sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
}
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(),
buf_size) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
}
}
TcpServer::TcpServer(const std::string &address, std::uint16_t port)
: base_(nullptr), : base_(nullptr),
signal_event_(nullptr), signal_event_(nullptr),
listener_(nullptr), listener_(nullptr),
@ -74,7 +85,7 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon
this->client_accept_ = client_accept; this->client_accept_ = client_accept;
} }
void TcpServer::InitServer() { void TcpServer::Init() {
base_ = event_base_new(); base_ = event_base_new();
MS_EXCEPTION_IF_NULL(base_); MS_EXCEPTION_IF_NULL(base_);
CommUtil::CheckIp(server_address_); CommUtil::CheckIp(server_address_);
@ -101,19 +112,26 @@ void TcpServer::InitServer() {
} }
void TcpServer::Start() { void TcpServer::Start() {
std::unique_lock<std::recursive_mutex> l(connection_mutex_); std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server!"; MS_LOG(INFO) << "Start tcp server!";
MS_EXCEPTION_IF_NULL(base_); MS_EXCEPTION_IF_NULL(base_);
int ret = event_base_dispatch(base_); int ret = event_base_dispatch(base_);
if (ret == 0) { MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
MS_LOG(INFO) << "Event base dispatch success!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
} else if (ret == 1) { << "Event base dispatch failed with no events pending or active!";
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
} else if (ret == -1) { MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; }
} else {
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; void TcpServer::StartWithNoBlock() {
} std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server with no block!";
MS_EXCEPTION_IF_NULL(base_);
int ret = event_base_loop(base_, EVLOOP_NONBLOCK);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
} }
void TcpServer::Stop() { void TcpServer::Stop() {
@ -150,6 +168,8 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co
void TcpServer::RemoveConnection(const evutil_socket_t &fd) { void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_); std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
delete connection;
connections_.erase(fd); connections_.erase(fd);
} }
@ -166,10 +186,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
return; return;
} }
TcpConnection *conn = server->onCreateConnection(); TcpConnection *conn = server->onCreateConnection(bev, fd);
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
conn->InitConnection(fd, bev, server); conn->InitConnection();
server->AddConnection(fd, conn); server->AddConnection(fd, conn);
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn)); bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn));
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
@ -177,17 +197,18 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
} }
} }
TcpConnection *TcpServer::onCreateConnection() { TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
TcpConnection *conn = nullptr; TcpConnection *conn = nullptr;
if (client_accept_) if (client_accept_) {
conn = const_cast<TcpConnection *>(client_accept_(this)); conn = const_cast<TcpConnection *>(client_accept_(*this));
else } else {
conn = new TcpConnection(); conn = new TcpConnection(bev, fd, this);
}
return conn; return conn;
} }
OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; } OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; }
void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) { void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
auto server = reinterpret_cast<class TcpServer *>(data); auto server = reinterpret_cast<class TcpServer *>(data);
@ -207,9 +228,9 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
auto conn = static_cast<class TcpConnection *>(connection); auto conn = static_cast<class TcpConnection *>(connection);
struct evbuffer *buf = bufferevent_get_input(bev); struct evbuffer *buf = bufferevent_get_input(bev);
char read_buffer[4096]; char read_buffer[4096];
auto read = 0; while (EVBUFFER_LENGTH(buf) > 0) {
while ((read = EVBUFFER_LENGTH(buf)) > 0) { int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer));
if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) { if (read == -1) {
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
} }
conn->OnReadHandler(read_buffer, static_cast<size_t>(read)); conn->OnReadHandler(read_buffer, static_cast<size_t>(read));
@ -219,43 +240,47 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) { void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) {
MS_EXCEPTION_IF_NULL(bev); MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
struct evbuffer *output = bufferevent_get_output(bev);
size_t remain = evbuffer_get_length(output);
auto conn = reinterpret_cast<TcpConnection *>(data); auto conn = reinterpret_cast<TcpConnection *>(data);
TcpServer *srv = conn->GetServer(); TcpServer *srv = conn->GetServer();
if (events & BEV_EVENT_EOF) { if (events & BEV_EVENT_EOF) {
MS_LOG(INFO) << "Event buffer end of file!";
// Notify about disconnection // Notify about disconnection
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); if (srv->client_disconnection_) {
srv->client_disconnection_(*srv, *conn);
}
// Free connection structures // Free connection structures
srv->RemoveConnection(conn->GetFd()); srv->RemoveConnection(conn->GetFd());
bufferevent_free(bev); bufferevent_free(bev);
} else if (events & BEV_EVENT_ERROR) { } else if (events & BEV_EVENT_ERROR) {
MS_LOG(ERROR) << "Event buffer remain data: " << remain;
// Free connection structures // Free connection structures
srv->RemoveConnection(conn->GetFd()); srv->RemoveConnection(conn->GetFd());
bufferevent_free(bev); bufferevent_free(bev);
// Notify about disconnection // Notify about disconnection
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); if (srv->client_disconnection_) {
srv->client_disconnection_(*srv, *conn);
}
} else { } else {
MS_LOG(ERROR) << "Unhandled event!"; MS_LOG(ERROR) << "Unhandled event!";
} }
} }
void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; } void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }
void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) { void TcpServer::SendMessage(const CommMessage &message) {
MS_EXCEPTION_IF_NULL(data);
auto mc = const_cast<TcpConnection &>(conn);
mc.SendMessage(data, num);
}
void TcpServer::SendMessage(const void *data, size_t num) {
MS_EXCEPTION_IF_NULL(data);
std::unique_lock<std::recursive_mutex> lock(connection_mutex_); std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
for (auto it = connections_.begin(); it != connections_.end(); ++it) { for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(*it->second, data, num); SendMessage(*it->second, message);
} }
} }
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
} // namespace comm } // namespace comm
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

@ -27,6 +27,8 @@
#include <map> #include <map>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <memory>
#include <vector>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "ps/comm/tcp_message_handler.h" #include "ps/comm/tcp_message_handler.h"
@ -38,46 +40,49 @@ namespace comm {
class TcpServer; class TcpServer;
class TcpConnection { class TcpConnection {
public: public:
TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {} explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server)
: buffer_event_(bev), fd_(0), server_(server) {}
virtual ~TcpConnection() = default; virtual ~TcpConnection() = default;
virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server); virtual void InitConnection();
void SendMessage(const void *buffer, size_t num) const; virtual void SendMessage(const void *buffer, size_t num) const;
void SendMessage(const CommMessage &message) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes); virtual void OnReadHandler(const void *buffer, size_t numBytes);
TcpServer *GetServer() const; TcpServer *GetServer() const;
evutil_socket_t GetFd() const; const evutil_socket_t &GetFd() const;
protected: protected:
TcpMessageHandler tcp_message_handler_;
struct bufferevent *buffer_event_; struct bufferevent *buffer_event_;
evutil_socket_t fd_; evutil_socket_t fd_;
TcpServer *server_; const TcpServer *server_;
TcpMessageHandler tcp_message_handler_;
}; };
using OnServerReceiveMessage = using OnServerReceiveMessage =
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const void *buffer, size_t num)>; std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const CommMessage &)>;
class TcpServer { class TcpServer {
public: public:
using OnConnected = std::function<void(const TcpServer *, const TcpConnection *)>; using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnDisconnected = std::function<void(const TcpServer *, const TcpConnection *)>; using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnAccepted = std::function<const TcpConnection *(const TcpServer *)>; using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>;
explicit TcpServer(std::string address, std::uint16_t port); explicit TcpServer(const std::string &address, std::uint16_t port);
virtual ~TcpServer(); virtual ~TcpServer();
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
const OnAccepted &client_accept); const OnAccepted &client_accept);
void InitServer(); void Init();
void Start(); void Start();
void StartWithNoBlock();
void Stop(); void Stop();
void SendToAllClients(const char *data, size_t len); void SendToAllClients(const char *data, size_t len);
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
void RemoveConnection(const evutil_socket_t &fd); void RemoveConnection(const evutil_socket_t &fd);
void ReceiveMessage(const OnServerReceiveMessage &cb); OnServerReceiveMessage GetServerReceive() const;
static void SendMessage(const TcpConnection &conn, const void *data, size_t num); void SetMessageCallback(const OnServerReceiveMessage &cb);
void SendMessage(const void *data, size_t num); static void SendMessage(const TcpConnection &conn, const CommMessage &message);
OnServerReceiveMessage GetServerReceiveMessage() const; void SendMessage(const CommMessage &message);
protected: protected:
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
@ -85,9 +90,8 @@ class TcpServer {
static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server); static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server);
static void ReadCallback(struct bufferevent *, void *connection); static void ReadCallback(struct bufferevent *, void *connection);
static void EventCallback(struct bufferevent *, std::int16_t events, void *server); static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
virtual TcpConnection *onCreateConnection(); virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
private:
struct event_base *base_; struct event_base *base_;
struct event *signal_event_; struct event *signal_event_;
struct evconnlistener *listener_; struct evconnlistener *listener_;
@ -101,6 +105,7 @@ class TcpServer {
std::recursive_mutex connection_mutex_; std::recursive_mutex connection_mutex_;
OnServerReceiveMessage message_callback_; OnServerReceiveMessage message_callback_;
}; };
} // namespace comm } // namespace comm
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

@ -24,6 +24,7 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <thread> #include <thread>
#include <memory>
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
@ -31,7 +32,9 @@ namespace comm {
class TestHttpServer : public UT::Common { class TestHttpServer : public UT::Common {
public: public:
TestHttpServer() = default; TestHttpServer() : server_(nullptr) {}
virtual ~TestHttpServer() = default;
static void testGetHandler(std::shared_ptr<HttpMessageHandler> resp) { static void testGetHandler(std::shared_ptr<HttpMessageHandler> resp) {
std::string host = resp->GetRequestHost(); std::string host = resp->GetRequestHost();
@ -57,7 +60,7 @@ class TestHttpServer : public UT::Common {
} }
void SetUp() override { void SetUp() override {
server_ = new HttpServer("0.0.0.0", 9999); server_ = std::make_unique<HttpServer>("0.0.0.0", 9999);
OnRequestReceive http_get_func = std::bind( OnRequestReceive http_get_func = std::bind(
[](std::shared_ptr<HttpMessageHandler> resp) { [](std::shared_ptr<HttpMessageHandler> resp) {
EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1"); EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1");
@ -106,7 +109,7 @@ class TestHttpServer : public UT::Common {
} }
private: private:
HttpServer *server_; std::unique_ptr<HttpServer> server_;
}; };
TEST_F(TestHttpServer, httpGetQequest) { TEST_F(TestHttpServer, httpGetQequest) {
@ -143,13 +146,13 @@ TEST_F(TestHttpServer, messageHandler) {
} }
TEST_F(TestHttpServer, portErrorNoException) { TEST_F(TestHttpServer, portErrorNoException) {
HttpServer *server_exception = new HttpServer("0.0.0.0", -1); auto server_exception = std::make_unique<HttpServer>("0.0.0.0", -1);
OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func)); EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func));
} }
TEST_F(TestHttpServer, addressException) { TEST_F(TestHttpServer, addressException) {
HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998); auto server_exception = std::make_unique<HttpServer>("12344.0.0.0", 9998);
OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception);
} }

@ -14,6 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "common/common_test.h" #include "common/common_test.h"
#include "ps/comm/tcp_client.h" #include "ps/comm/tcp_client.h"
@ -26,19 +28,19 @@ class TestTcpClient : public UT::Common {
}; };
TEST_F(TestTcpClient, InitClientIPError) { TEST_F(TestTcpClient, InitClientIPError) {
auto client = new TcpClient("127.0.0.13543", 9000); auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000);
client->ReceiveMessage(
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
ASSERT_THROW(client->InitTcpClient(), std::exception); ASSERT_THROW(client->Init(), std::exception);
} }
TEST_F(TestTcpClient, InitClientPortErrorNoException) { TEST_F(TestTcpClient, InitClientPortErrorNoException) {
auto client = new TcpClient("127.0.0.1", -1); auto client = std::make_unique<TcpClient>("127.0.0.1", -1);
client->ReceiveMessage(
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
EXPECT_NO_THROW(client->InitTcpClient()); EXPECT_NO_THROW(client->Init());
} }
} // namespace comm } // namespace comm

@ -18,6 +18,7 @@
#include "ps/comm/tcp_server.h" #include "ps/comm/tcp_server.h"
#include "common/common_test.h" #include "common/common_test.h"
#include <memory>
#include <thread> #include <thread>
namespace mindspore { namespace mindspore {
@ -25,16 +26,20 @@ namespace ps {
namespace comm { namespace comm {
class TestTcpServer : public UT::Common { class TestTcpServer : public UT::Common {
public: public:
TestTcpServer() = default; TestTcpServer() : client_(nullptr), server_(nullptr) {}
virtual ~TestTcpServer() = default;
void SetUp() override { void SetUp() override {
server_ = new TcpServer("127.0.0.1", 9000); server_ = std::make_unique<TcpServer>("127.0.0.1", 9998);
std::unique_ptr<std::thread> http_server_thread_(nullptr); std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([&]() { http_server_thread_ = std::make_unique<std::thread>([&]() {
server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) { server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE"); KVMessage kv_message;
server.SendMessage(conn, buffer, num); kv_message.ParseFromString(message.data());
EXPECT_EQ(2, kv_message.keys_size());
server.SendMessage(conn, message);
}); });
server_->InitServer(); server_->Init();
server_->Start(); server_->Start();
}); });
http_server_thread_->detach(); http_server_thread_->detach();
@ -47,21 +52,32 @@ class TestTcpServer : public UT::Common {
server_->Stop(); server_->Stop();
} }
TcpClient *client_; std::unique_ptr<TcpClient> client_;
TcpServer *server_; std::unique_ptr<TcpServer> server_;
const std::string test_message_ = "TCP_MESSAGE";
}; };
TEST_F(TestTcpServer, ServerSendMessage) { TEST_F(TestTcpServer, ServerSendMessage) {
client_ = new TcpClient("127.0.0.1", 9000); client_ = std::make_unique<TcpClient>("127.0.0.1", 9998);
std::unique_ptr<std::thread> http_client_thread(nullptr); std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() { http_client_thread = std::make_unique<std::thread>([&]() {
client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) { client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE"); KVMessage kv_message;
kv_message.ParseFromString(message.data());
EXPECT_EQ(2, kv_message.keys_size());
}); });
client_->InitTcpClient(); client_->Init();
client_->SendMessage(test_message_.c_str(), test_message_.size());
CommMessage comm_message;
KVMessage kv_message;
std::vector<int> keys{1, 2};
std::vector<int> values{3, 4};
*kv_message.mutable_keys() = {keys.begin(), keys.end()};
*kv_message.mutable_values() = {values.begin(), values.end()};
comm_message.set_data(kv_message.SerializeAsString());
client_->SendMessage(comm_message);
client_->Start(); client_->Start();
}); });
http_client_thread->detach(); http_client_thread->detach();
Loading…
Cancel
Save