|
|
|
@ -33,16 +33,12 @@ namespace mindspore {
|
|
|
|
|
namespace ps {
|
|
|
|
|
namespace comm {
|
|
|
|
|
|
|
|
|
|
void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bev);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(server);
|
|
|
|
|
buffer_event_ = const_cast<struct bufferevent *>(bev);
|
|
|
|
|
fd_ = fd;
|
|
|
|
|
server_ = const_cast<TcpServer *>(server);
|
|
|
|
|
|
|
|
|
|
tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) {
|
|
|
|
|
OnServerReceiveMessage message_callback = server->GetServerReceiveMessage();
|
|
|
|
|
if (message_callback) message_callback(*server, *this, buf, num);
|
|
|
|
|
void TcpConnection::InitConnection() {
|
|
|
|
|
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
|
|
|
|
|
OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
|
|
|
|
|
if (on_server_receive) {
|
|
|
|
|
on_server_receive(*server_, *this, message);
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -54,11 +50,26 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TcpServer *TcpConnection::GetServer() const { return server_; }
|
|
|
|
|
TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); }
|
|
|
|
|
|
|
|
|
|
evutil_socket_t TcpConnection::GetFd() const { return fd_; }
|
|
|
|
|
const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
|
|
|
|
|
|
|
|
|
|
TcpServer::TcpServer(std::string address, std::uint16_t port)
|
|
|
|
|
void TcpConnection::SendMessage(const CommMessage &message) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(buffer_event_);
|
|
|
|
|
uint32_t buf_size = message.ByteSizeLong();
|
|
|
|
|
std::vector<unsigned char> serialized(buf_size);
|
|
|
|
|
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
|
|
|
|
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,
|
|
|
|
|
sizeof(buf_size)) == -1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
|
|
|
|
|
}
|
|
|
|
|
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(),
|
|
|
|
|
buf_size) == -1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TcpServer::TcpServer(const std::string &address, std::uint16_t port)
|
|
|
|
|
: base_(nullptr),
|
|
|
|
|
signal_event_(nullptr),
|
|
|
|
|
listener_(nullptr),
|
|
|
|
@ -74,7 +85,7 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon
|
|
|
|
|
this->client_accept_ = client_accept;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TcpServer::InitServer() {
|
|
|
|
|
void TcpServer::Init() {
|
|
|
|
|
base_ = event_base_new();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(base_);
|
|
|
|
|
CommUtil::CheckIp(server_address_);
|
|
|
|
@ -101,19 +112,26 @@ void TcpServer::InitServer() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TcpServer::Start() {
|
|
|
|
|
std::unique_lock<std::recursive_mutex> l(connection_mutex_);
|
|
|
|
|
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
|
|
|
|
MS_LOG(INFO) << "Start tcp server!";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(base_);
|
|
|
|
|
int ret = event_base_dispatch(base_);
|
|
|
|
|
if (ret == 0) {
|
|
|
|
|
MS_LOG(INFO) << "Event base dispatch success!";
|
|
|
|
|
} else if (ret == 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
|
|
|
|
|
} else if (ret == -1) {
|
|
|
|
|
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
|
|
|
|
|
}
|
|
|
|
|
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
|
|
|
|
|
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
|
|
|
|
|
<< "Event base dispatch failed with no events pending or active!";
|
|
|
|
|
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
|
|
|
|
|
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TcpServer::StartWithNoBlock() {
|
|
|
|
|
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
|
|
|
|
MS_LOG(INFO) << "Start tcp server with no block!";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(base_);
|
|
|
|
|
int ret = event_base_loop(base_, EVLOOP_NONBLOCK);
|
|
|
|
|
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
|
|
|
|
|
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
|
|
|
|
|
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
|
|
|
|
|
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TcpServer::Stop() {
|
|
|
|
@ -150,6 +168,8 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co
|
|
|
|
|
|
|
|
|
|
void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
|
|
|
|
|
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
|
|
|
|
TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
|
|
|
|
|
delete connection;
|
|
|
|
|
connections_.erase(fd);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -166,10 +186,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TcpConnection *conn = server->onCreateConnection();
|
|
|
|
|
TcpConnection *conn = server->onCreateConnection(bev, fd);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conn);
|
|
|
|
|
|
|
|
|
|
conn->InitConnection(fd, bev, server);
|
|
|
|
|
conn->InitConnection();
|
|
|
|
|
server->AddConnection(fd, conn);
|
|
|
|
|
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn));
|
|
|
|
|
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
|
|
|
|
@ -177,17 +197,18 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TcpConnection *TcpServer::onCreateConnection() {
|
|
|
|
|
TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
|
|
|
|
|
TcpConnection *conn = nullptr;
|
|
|
|
|
if (client_accept_)
|
|
|
|
|
conn = const_cast<TcpConnection *>(client_accept_(this));
|
|
|
|
|
else
|
|
|
|
|
conn = new TcpConnection();
|
|
|
|
|
if (client_accept_) {
|
|
|
|
|
conn = const_cast<TcpConnection *>(client_accept_(*this));
|
|
|
|
|
} else {
|
|
|
|
|
conn = new TcpConnection(bev, fd, this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return conn;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; }
|
|
|
|
|
OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; }
|
|
|
|
|
|
|
|
|
|
void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
|
|
|
|
|
auto server = reinterpret_cast<class TcpServer *>(data);
|
|
|
|
@ -207,9 +228,9 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
|
|
|
|
|
auto conn = static_cast<class TcpConnection *>(connection);
|
|
|
|
|
struct evbuffer *buf = bufferevent_get_input(bev);
|
|
|
|
|
char read_buffer[4096];
|
|
|
|
|
auto read = 0;
|
|
|
|
|
while ((read = EVBUFFER_LENGTH(buf)) > 0) {
|
|
|
|
|
if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) {
|
|
|
|
|
while (EVBUFFER_LENGTH(buf) > 0) {
|
|
|
|
|
int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer));
|
|
|
|
|
if (read == -1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
|
|
|
|
|
}
|
|
|
|
|
conn->OnReadHandler(read_buffer, static_cast<size_t>(read));
|
|
|
|
@ -219,43 +240,47 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
|
|
|
|
|
void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bev);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(data);
|
|
|
|
|
struct evbuffer *output = bufferevent_get_output(bev);
|
|
|
|
|
size_t remain = evbuffer_get_length(output);
|
|
|
|
|
auto conn = reinterpret_cast<TcpConnection *>(data);
|
|
|
|
|
TcpServer *srv = conn->GetServer();
|
|
|
|
|
|
|
|
|
|
if (events & BEV_EVENT_EOF) {
|
|
|
|
|
MS_LOG(INFO) << "Event buffer end of file!";
|
|
|
|
|
// Notify about disconnection
|
|
|
|
|
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
|
|
|
|
|
if (srv->client_disconnection_) {
|
|
|
|
|
srv->client_disconnection_(*srv, *conn);
|
|
|
|
|
}
|
|
|
|
|
// Free connection structures
|
|
|
|
|
srv->RemoveConnection(conn->GetFd());
|
|
|
|
|
bufferevent_free(bev);
|
|
|
|
|
} else if (events & BEV_EVENT_ERROR) {
|
|
|
|
|
MS_LOG(ERROR) << "Event buffer remain data: " << remain;
|
|
|
|
|
// Free connection structures
|
|
|
|
|
srv->RemoveConnection(conn->GetFd());
|
|
|
|
|
bufferevent_free(bev);
|
|
|
|
|
|
|
|
|
|
// Notify about disconnection
|
|
|
|
|
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
|
|
|
|
|
if (srv->client_disconnection_) {
|
|
|
|
|
srv->client_disconnection_(*srv, *conn);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unhandled event!";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
|
|
|
|
|
void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }
|
|
|
|
|
|
|
|
|
|
void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(data);
|
|
|
|
|
auto mc = const_cast<TcpConnection &>(conn);
|
|
|
|
|
mc.SendMessage(data, num);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TcpServer::SendMessage(const void *data, size_t num) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(data);
|
|
|
|
|
void TcpServer::SendMessage(const CommMessage &message) {
|
|
|
|
|
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
|
|
|
|
|
|
|
|
|
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
|
|
|
|
|
SendMessage(*it->second, data, num);
|
|
|
|
|
SendMessage(*it->second, message);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
|
|
|
|
|
|
|
|
|
|
} // namespace comm
|
|
|
|
|
} // namespace ps
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|