From 87edbdf7208c53d00dc627342dac5c484061bd66 Mon Sep 17 00:00:00 2001 From: anancds Date: Tue, 27 Oct 2020 17:25:48 +0800 Subject: [PATCH] added protobuf message --- mindspore/ccsrc/CMakeLists.txt | 7 +- mindspore/ccsrc/ps/comm/protos/comm.proto | 42 +++++++ mindspore/ccsrc/ps/comm/protos/ps.proto | 25 ++++ mindspore/ccsrc/ps/comm/tcp_client.cc | 59 +++++---- mindspore/ccsrc/ps/comm/tcp_client.h | 19 ++- .../ccsrc/ps/comm/tcp_message_handler.cc | 48 ++++++- mindspore/ccsrc/ps/comm/tcp_message_handler.h | 23 +++- mindspore/ccsrc/ps/comm/tcp_server.cc | 119 +++++++++++------- mindspore/ccsrc/ps/comm/tcp_server.h | 41 +++--- tests/ut/cpp/ps/comm/http_server_test.cc | 13 +- tests/ut/cpp/ps/comm/tcp_client_tests.cc | 18 +-- ..._server_tests.cc => tcp_pb_server_test.cc} | 44 ++++--- 12 files changed, 329 insertions(+), 129 deletions(-) create mode 100644 mindspore/ccsrc/ps/comm/protos/comm.proto create mode 100644 mindspore/ccsrc/ps/comm/protos/ps.proto rename tests/ut/cpp/ps/comm/{tcp_server_tests.cc => tcp_pb_server_test.cc} (56%) diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 756767c799..38c325a127 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -100,6 +100,11 @@ message("onnx proto path is :" ${ONNX_PROTO}) ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO}) list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS}) +include_directories("${CMAKE_BINARY_DIR}/ps/comm") +file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/comm/protos/*.proto") +ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN}) +list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS}) + if (ENABLE_DEBUGGER) # debugger: compile proto files include_directories("${CMAKE_BINARY_DIR}/debug/debugger") @@ -290,7 +295,7 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows") target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) else () if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) - target_link_libraries(mindspore mindspore::pslite mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) + target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) if (${ENABLE_IBVERBS} STREQUAL "ON") target_link_libraries(mindspore ibverbs rdmacm) endif() diff --git a/mindspore/ccsrc/ps/comm/protos/comm.proto b/mindspore/ccsrc/ps/comm/protos/comm.proto new file mode 100644 index 0000000000..653af8edfe --- /dev/null +++ b/mindspore/ccsrc/ps/comm/protos/comm.proto @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +import "google/protobuf/any.proto"; +package mindspore.ps; +option optimize_for = LITE_RUNTIME; + +message MessageMeta { + // hostname or ip + string hostname = 1; + // the port of this node + int32 port = 2; + // the command of this message,for example: register、heartbeat、data + int32 cmd = 3; + // the timestamp of this message + int32 timestamp = 4; + // data type of message + repeated int32 data_type = 5 [packed = true]; + // message.data_size + int32 data_size = 6; +} + + +message CommMessage { + MessageMeta pb_meta = 1; + bytes data = 2; +} + diff --git a/mindspore/ccsrc/ps/comm/protos/ps.proto b/mindspore/ccsrc/ps/comm/protos/ps.proto new file mode 100644 index 0000000000..9cee1712bf --- /dev/null +++ b/mindspore/ccsrc/ps/comm/protos/ps.proto @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +message KVMessage { + repeated int32 keys = 1; + repeated float values = 2; +} + +message HeartBeatMessage { + // *.*.*.*:port + repeated string host_and_port = 1; +} \ No newline at end of file diff --git a/mindspore/ccsrc/ps/comm/tcp_client.cc b/mindspore/ccsrc/ps/comm/tcp_client.cc index 3b28bfbc6c..b55aa18af2 100644 --- a/mindspore/ccsrc/ps/comm/tcp_client.cc +++ b/mindspore/ccsrc/ps/comm/tcp_client.cc @@ -36,18 +36,16 @@ namespace mindspore { namespace ps { namespace comm { -TcpClient::TcpClient(std::string address, std::uint16_t port) +TcpClient::TcpClient(const std::string &address, std::uint16_t port) : event_base_(nullptr), event_timeout_(nullptr), buffer_event_(nullptr), server_address_(std::move(address)), server_port_(port) { - message_handler_.SetCallback([this](const void *buf, size_t num) { - if (buf == nullptr) { - if (disconnected_callback_) disconnected_callback_(*this, 200); - Stop(); + message_handler_.SetCallback([this](const CommMessage &message) { + if (message_callback_) { + message_callback_(*this, message); } - if (message_callback_) message_callback_(*this, buf, num); }); } @@ -63,7 +61,7 @@ void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disco timeout_callback_ = timeout; } -void TcpClient::InitTcpClient() { +void TcpClient::Init() { if (buffer_event_) { return; } @@ -139,7 +137,7 @@ void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) { void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) { MS_EXCEPTION_IF_NULL(arg); auto tcp_client = reinterpret_cast(arg); - tcp_client->InitTcpClient(); + tcp_client->Init(); } void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { @@ -150,10 +148,10 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { MS_EXCEPTION_IF_NULL(input); char read_buffer[4096]; - int read = 0; - while ((read = EVBUFFER_LENGTH(input)) > 0) { - if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) { + while (EVBUFFER_LENGTH(input) > 0) { + int read = evbuffer_remove(input, &read_buffer, sizeof(read_buffer)); + if (read == -1) { MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; } tcp_client->OnReadHandler(read_buffer, read); @@ -196,25 +194,38 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void void TcpClient::Start() { MS_EXCEPTION_IF_NULL(event_base_); int ret = event_base_dispatch(event_base_); - if (ret == 0) { - MS_LOG(INFO) << "Event base dispatch success!"; - } else if (ret == 1) { - MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; - } else if (ret == -1) { - MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; - } else { - MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; - } + MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; + MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) + << "Event base dispatch failed with no events pending or active!"; + MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; + MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; +} + +void TcpClient::StartWithNoBlock() { + MS_LOG(INFO) << "Start tcp client with no block!"; + MS_EXCEPTION_IF_NULL(event_base_); + int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK); + MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; + MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; + MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; + MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; } -void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; } +void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } -void TcpClient::SendMessage(const void *buf, size_t num) const { +void TcpClient::SendMessage(const CommMessage &message) const { MS_EXCEPTION_IF_NULL(buffer_event_); - if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) { - MS_LOG(EXCEPTION) << "Event buffer add failed!"; + uint32_t buf_size = message.ByteSizeLong(); + std::vector serialized(buf_size); + message.SerializeToArray(serialized.data(), static_cast(buf_size)); + if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { + MS_LOG(EXCEPTION) << "Event buffer add header failed!"; + } + if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) { + MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; } } + } // namespace comm } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_client.h b/mindspore/ccsrc/ps/comm/tcp_client.h index 49d7478dab..2108e1db85 100644 --- a/mindspore/ccsrc/ps/comm/tcp_client.h +++ b/mindspore/ccsrc/ps/comm/tcp_client.h @@ -23,6 +23,10 @@ #include #include #include +#include +#include + +#include "proto/comm.pb.h" namespace mindspore { namespace ps { @@ -30,24 +34,25 @@ namespace comm { class TcpClient { public: - using OnMessage = std::function; using OnConnected = std::function; using OnDisconnected = std::function; using OnRead = std::function; using OnTimeout = std::function; + using OnMessage = std::function; - explicit TcpClient(std::string address, std::uint16_t port); + explicit TcpClient(const std::string &address, std::uint16_t port); virtual ~TcpClient(); std::string GetServerAddress() const; void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, const OnTimeout &timeout); - void InitTcpClient(); + void Init(); void StartWithDelay(int seconds); void Stop(); - void ReceiveMessage(const OnMessage &cb); - void SendMessage(const void *buf, size_t num) const; void Start(); + void StartWithNoBlock(); + void SetMessageCallback(const OnMessage &cb); + void SendMessage(const CommMessage &message) const; protected: static void SetTcpNoDelay(const evutil_socket_t &fd); @@ -57,8 +62,9 @@ class TcpClient { virtual void OnReadHandler(const void *buf, size_t num); private: - TcpMessageHandler message_handler_; OnMessage message_callback_; + TcpMessageHandler message_handler_; + OnConnected connected_callback_; OnDisconnected disconnected_callback_; OnRead read_callback_; @@ -71,6 +77,7 @@ class TcpClient { std::string server_address_; std::uint16_t server_port_; }; + } // namespace comm } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_message_handler.cc b/mindspore/ccsrc/ps/comm/tcp_message_handler.cc index 5755802346..97285bdcc9 100644 --- a/mindspore/ccsrc/ps/comm/tcp_message_handler.cc +++ b/mindspore/ccsrc/ps/comm/tcp_message_handler.cc @@ -15,6 +15,8 @@ */ #include "ps/comm/tcp_message_handler.h" + +#include #include #include @@ -22,15 +24,55 @@ namespace mindspore { namespace ps { namespace comm { -void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); } +void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; } void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { MS_EXCEPTION_IF_NULL(buffer); + auto buffer_data = reinterpret_cast(buffer); + + while (num > 0) { + if (remaining_length_ == 0) { + for (int i = 0; i < 4 && num > 0; ++i) { + header_[++header_index_] = *(buffer_data + i); + --num; + if (header_index_ == 3) { + message_length_ = *reinterpret_cast(header_); + message_length_ = ntohl(message_length_); + remaining_length_ = message_length_; + message_buffer_.reset(new unsigned char[remaining_length_]); + buffer_data += i; + break; + } + } + } + + if (remaining_length_ > 0) { + uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num; + remaining_length_ -= copy_len; + num -= copy_len; - if (message_callback_) { - message_callback_(buffer, num); + int ret = memcpy_s(message_buffer_.get() + last_copy_len_, copy_len, buffer_data, copy_len); + last_copy_len_ += copy_len; + buffer_data += copy_len; + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + if (remaining_length_ == 0) { + CommMessage pb_message; + pb_message.ParseFromArray(reinterpret_cast(message_buffer_.get()), message_length_); + if (message_callback_) { + message_callback_(pb_message); + } + message_buffer_.reset(); + message_buffer_ = nullptr; + header_index_ = 0; + last_copy_len_ = 0; + } + } } } + } // namespace comm } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_message_handler.h b/mindspore/ccsrc/ps/comm/tcp_message_handler.h index 339e25a06a..58686c781e 100644 --- a/mindspore/ccsrc/ps/comm/tcp_message_handler.h +++ b/mindspore/ccsrc/ps/comm/tcp_message_handler.h @@ -19,26 +19,43 @@ #include #include +#include #include +#include #include "utils/log_adapter.h" +#include "proto/comm.pb.h" +#include "proto/ps.pb.h" namespace mindspore { namespace ps { namespace comm { -using messageReceive = std::function; +using messageReceive = std::function; class TcpMessageHandler { public: - TcpMessageHandler() = default; + TcpMessageHandler() + : is_parsed_(false), + message_buffer_(nullptr), + message_length_(0), + remaining_length_(0), + header_index_(-1), + last_copy_len_(0) {} virtual ~TcpMessageHandler() = default; - void SetCallback(messageReceive cb); + void SetCallback(const messageReceive &cb); void ReceiveMessage(const void *buffer, size_t num); private: messageReceive message_callback_; + bool is_parsed_; + std::unique_ptr message_buffer_; + size_t message_length_; + uint32_t remaining_length_; + char header_[4]; + int header_index_; + uint32_t last_copy_len_; }; } // namespace comm } // namespace ps diff --git a/mindspore/ccsrc/ps/comm/tcp_server.cc b/mindspore/ccsrc/ps/comm/tcp_server.cc index bd5a9c8088..4cf70f3b2d 100644 --- a/mindspore/ccsrc/ps/comm/tcp_server.cc +++ b/mindspore/ccsrc/ps/comm/tcp_server.cc @@ -33,16 +33,12 @@ namespace mindspore { namespace ps { namespace comm { -void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) { - MS_EXCEPTION_IF_NULL(bev); - MS_EXCEPTION_IF_NULL(server); - buffer_event_ = const_cast(bev); - fd_ = fd; - server_ = const_cast(server); - - tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) { - OnServerReceiveMessage message_callback = server->GetServerReceiveMessage(); - if (message_callback) message_callback(*server, *this, buf, num); +void TcpConnection::InitConnection() { + tcp_message_handler_.SetCallback([&](const CommMessage &message) { + OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); + if (on_server_receive) { + on_server_receive(*server_, *this, message); + } }); } @@ -54,11 +50,26 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const { } } -TcpServer *TcpConnection::GetServer() const { return server_; } +TcpServer *TcpConnection::GetServer() const { return const_cast(server_); } -evutil_socket_t TcpConnection::GetFd() const { return fd_; } +const evutil_socket_t &TcpConnection::GetFd() const { return fd_; } -TcpServer::TcpServer(std::string address, std::uint16_t port) +void TcpConnection::SendMessage(const CommMessage &message) const { + MS_EXCEPTION_IF_NULL(buffer_event_); + uint32_t buf_size = message.ByteSizeLong(); + std::vector serialized(buf_size); + message.SerializeToArray(serialized.data(), static_cast(buf_size)); + if (evbuffer_add(bufferevent_get_output(const_cast(buffer_event_)), &buf_size, + sizeof(buf_size)) == -1) { + MS_LOG(EXCEPTION) << "Event buffer add header failed!"; + } + if (evbuffer_add(bufferevent_get_output(const_cast(buffer_event_)), serialized.data(), + buf_size) == -1) { + MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; + } +} + +TcpServer::TcpServer(const std::string &address, std::uint16_t port) : base_(nullptr), signal_event_(nullptr), listener_(nullptr), @@ -74,7 +85,7 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon this->client_accept_ = client_accept; } -void TcpServer::InitServer() { +void TcpServer::Init() { base_ = event_base_new(); MS_EXCEPTION_IF_NULL(base_); CommUtil::CheckIp(server_address_); @@ -101,19 +112,26 @@ void TcpServer::InitServer() { } void TcpServer::Start() { - std::unique_lock l(connection_mutex_); + std::unique_lock lock(connection_mutex_); MS_LOG(INFO) << "Start tcp server!"; MS_EXCEPTION_IF_NULL(base_); int ret = event_base_dispatch(base_); - if (ret == 0) { - MS_LOG(INFO) << "Event base dispatch success!"; - } else if (ret == 1) { - MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; - } else if (ret == -1) { - MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; - } else { - MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; - } + MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; + MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) + << "Event base dispatch failed with no events pending or active!"; + MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; + MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; +} + +void TcpServer::StartWithNoBlock() { + std::unique_lock lock(connection_mutex_); + MS_LOG(INFO) << "Start tcp server with no block!"; + MS_EXCEPTION_IF_NULL(base_); + int ret = event_base_loop(base_, EVLOOP_NONBLOCK); + MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; + MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; + MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; + MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; } void TcpServer::Stop() { @@ -150,6 +168,8 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co void TcpServer::RemoveConnection(const evutil_socket_t &fd) { std::unique_lock lock(connection_mutex_); + TcpConnection *connection = const_cast(connections_.find(fd)->second); + delete connection; connections_.erase(fd); } @@ -166,10 +186,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st return; } - TcpConnection *conn = server->onCreateConnection(); + TcpConnection *conn = server->onCreateConnection(bev, fd); MS_EXCEPTION_IF_NULL(conn); - conn->InitConnection(fd, bev, server); + conn->InitConnection(); server->AddConnection(fd, conn); bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast(conn)); if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { @@ -177,17 +197,18 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st } } -TcpConnection *TcpServer::onCreateConnection() { +TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { TcpConnection *conn = nullptr; - if (client_accept_) - conn = const_cast(client_accept_(this)); - else - conn = new TcpConnection(); + if (client_accept_) { + conn = const_cast(client_accept_(*this)); + } else { + conn = new TcpConnection(bev, fd, this); + } return conn; } -OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; } +OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; } void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) { auto server = reinterpret_cast(data); @@ -207,9 +228,9 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { auto conn = static_cast(connection); struct evbuffer *buf = bufferevent_get_input(bev); char read_buffer[4096]; - auto read = 0; - while ((read = EVBUFFER_LENGTH(buf)) > 0) { - if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) { + while (EVBUFFER_LENGTH(buf) > 0) { + int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)); + if (read == -1) { MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; } conn->OnReadHandler(read_buffer, static_cast(read)); @@ -219,43 +240,47 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) { MS_EXCEPTION_IF_NULL(bev); MS_EXCEPTION_IF_NULL(data); + struct evbuffer *output = bufferevent_get_output(bev); + size_t remain = evbuffer_get_length(output); auto conn = reinterpret_cast(data); TcpServer *srv = conn->GetServer(); if (events & BEV_EVENT_EOF) { + MS_LOG(INFO) << "Event buffer end of file!"; // Notify about disconnection - if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); + if (srv->client_disconnection_) { + srv->client_disconnection_(*srv, *conn); + } // Free connection structures srv->RemoveConnection(conn->GetFd()); bufferevent_free(bev); } else if (events & BEV_EVENT_ERROR) { + MS_LOG(ERROR) << "Event buffer remain data: " << remain; // Free connection structures srv->RemoveConnection(conn->GetFd()); bufferevent_free(bev); // Notify about disconnection - if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); + if (srv->client_disconnection_) { + srv->client_disconnection_(*srv, *conn); + } } else { MS_LOG(ERROR) << "Unhandled event!"; } } -void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; } +void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } -void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) { - MS_EXCEPTION_IF_NULL(data); - auto mc = const_cast(conn); - mc.SendMessage(data, num); -} - -void TcpServer::SendMessage(const void *data, size_t num) { - MS_EXCEPTION_IF_NULL(data); +void TcpServer::SendMessage(const CommMessage &message) { std::unique_lock lock(connection_mutex_); for (auto it = connections_.begin(); it != connections_.end(); ++it) { - SendMessage(*it->second, data, num); + SendMessage(*it->second, message); } } + +void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } + } // namespace comm } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_server.h b/mindspore/ccsrc/ps/comm/tcp_server.h index ccb9ef5e8a..f88cc954ab 100644 --- a/mindspore/ccsrc/ps/comm/tcp_server.h +++ b/mindspore/ccsrc/ps/comm/tcp_server.h @@ -27,6 +27,8 @@ #include #include #include +#include +#include #include "utils/log_adapter.h" #include "ps/comm/tcp_message_handler.h" @@ -38,46 +40,49 @@ namespace comm { class TcpServer; class TcpConnection { public: - TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {} + explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) + : buffer_event_(bev), fd_(0), server_(server) {} virtual ~TcpConnection() = default; - virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server); - void SendMessage(const void *buffer, size_t num) const; + virtual void InitConnection(); + virtual void SendMessage(const void *buffer, size_t num) const; + void SendMessage(const CommMessage &message) const; virtual void OnReadHandler(const void *buffer, size_t numBytes); TcpServer *GetServer() const; - evutil_socket_t GetFd() const; + const evutil_socket_t &GetFd() const; protected: - TcpMessageHandler tcp_message_handler_; struct bufferevent *buffer_event_; evutil_socket_t fd_; - TcpServer *server_; + const TcpServer *server_; + TcpMessageHandler tcp_message_handler_; }; using OnServerReceiveMessage = - std::function; + std::function; class TcpServer { public: - using OnConnected = std::function; - using OnDisconnected = std::function; - using OnAccepted = std::function; + using OnConnected = std::function; + using OnDisconnected = std::function; + using OnAccepted = std::function; - explicit TcpServer(std::string address, std::uint16_t port); + explicit TcpServer(const std::string &address, std::uint16_t port); virtual ~TcpServer(); void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, const OnAccepted &client_accept); - void InitServer(); + void Init(); void Start(); + void StartWithNoBlock(); void Stop(); void SendToAllClients(const char *data, size_t len); void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); void RemoveConnection(const evutil_socket_t &fd); - void ReceiveMessage(const OnServerReceiveMessage &cb); - static void SendMessage(const TcpConnection &conn, const void *data, size_t num); - void SendMessage(const void *data, size_t num); - OnServerReceiveMessage GetServerReceiveMessage() const; + OnServerReceiveMessage GetServerReceive() const; + void SetMessageCallback(const OnServerReceiveMessage &cb); + static void SendMessage(const TcpConnection &conn, const CommMessage &message); + void SendMessage(const CommMessage &message); protected: static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, @@ -85,9 +90,8 @@ class TcpServer { static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server); static void ReadCallback(struct bufferevent *, void *connection); static void EventCallback(struct bufferevent *, std::int16_t events, void *server); - virtual TcpConnection *onCreateConnection(); + virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); - private: struct event_base *base_; struct event *signal_event_; struct evconnlistener *listener_; @@ -101,6 +105,7 @@ class TcpServer { std::recursive_mutex connection_mutex_; OnServerReceiveMessage message_callback_; }; + } // namespace comm } // namespace ps } // namespace mindspore diff --git a/tests/ut/cpp/ps/comm/http_server_test.cc b/tests/ut/cpp/ps/comm/http_server_test.cc index c1ce166959..7c2f5dc6bc 100644 --- a/tests/ut/cpp/ps/comm/http_server_test.cc +++ b/tests/ut/cpp/ps/comm/http_server_test.cc @@ -24,6 +24,7 @@ #include #include #include +#include namespace mindspore { namespace ps { @@ -31,7 +32,9 @@ namespace comm { class TestHttpServer : public UT::Common { public: - TestHttpServer() = default; + TestHttpServer() : server_(nullptr) {} + + virtual ~TestHttpServer() = default; static void testGetHandler(std::shared_ptr resp) { std::string host = resp->GetRequestHost(); @@ -57,7 +60,7 @@ class TestHttpServer : public UT::Common { } void SetUp() override { - server_ = new HttpServer("0.0.0.0", 9999); + server_ = std::make_unique("0.0.0.0", 9999); OnRequestReceive http_get_func = std::bind( [](std::shared_ptr resp) { EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1"); @@ -106,7 +109,7 @@ class TestHttpServer : public UT::Common { } private: - HttpServer *server_; + std::unique_ptr server_; }; TEST_F(TestHttpServer, httpGetQequest) { @@ -143,13 +146,13 @@ TEST_F(TestHttpServer, messageHandler) { } TEST_F(TestHttpServer, portErrorNoException) { - HttpServer *server_exception = new HttpServer("0.0.0.0", -1); + auto server_exception = std::make_unique("0.0.0.0", -1); OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func)); } TEST_F(TestHttpServer, addressException) { - HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998); + auto server_exception = std::make_unique("12344.0.0.0", 9998); OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); } diff --git a/tests/ut/cpp/ps/comm/tcp_client_tests.cc b/tests/ut/cpp/ps/comm/tcp_client_tests.cc index 424e7cc286..a8b2b2ef3a 100644 --- a/tests/ut/cpp/ps/comm/tcp_client_tests.cc +++ b/tests/ut/cpp/ps/comm/tcp_client_tests.cc @@ -14,6 +14,8 @@ * limitations under the License. */ +#include + #include "common/common_test.h" #include "ps/comm/tcp_client.h" @@ -26,19 +28,19 @@ class TestTcpClient : public UT::Common { }; TEST_F(TestTcpClient, InitClientIPError) { - auto client = new TcpClient("127.0.0.13543", 9000); - client->ReceiveMessage( - [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); + auto client = std::make_unique("127.0.0.13543", 9000); + + client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); - ASSERT_THROW(client->InitTcpClient(), std::exception); + ASSERT_THROW(client->Init(), std::exception); } TEST_F(TestTcpClient, InitClientPortErrorNoException) { - auto client = new TcpClient("127.0.0.1", -1); - client->ReceiveMessage( - [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); + auto client = std::make_unique("127.0.0.1", -1); + + client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); - EXPECT_NO_THROW(client->InitTcpClient()); + EXPECT_NO_THROW(client->Init()); } } // namespace comm diff --git a/tests/ut/cpp/ps/comm/tcp_server_tests.cc b/tests/ut/cpp/ps/comm/tcp_pb_server_test.cc similarity index 56% rename from tests/ut/cpp/ps/comm/tcp_server_tests.cc rename to tests/ut/cpp/ps/comm/tcp_pb_server_test.cc index 1508c5b5ac..e041e2046f 100644 --- a/tests/ut/cpp/ps/comm/tcp_server_tests.cc +++ b/tests/ut/cpp/ps/comm/tcp_pb_server_test.cc @@ -18,6 +18,7 @@ #include "ps/comm/tcp_server.h" #include "common/common_test.h" +#include #include namespace mindspore { @@ -25,16 +26,20 @@ namespace ps { namespace comm { class TestTcpServer : public UT::Common { public: - TestTcpServer() = default; + TestTcpServer() : client_(nullptr), server_(nullptr) {} + virtual ~TestTcpServer() = default; + void SetUp() override { - server_ = new TcpServer("127.0.0.1", 9000); + server_ = std::make_unique("127.0.0.1", 9998); std::unique_ptr http_server_thread_(nullptr); http_server_thread_ = std::make_unique([&]() { - server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) { - EXPECT_STREQ(std::string(reinterpret_cast(buffer), num).c_str(), "TCP_MESSAGE"); - server.SendMessage(conn, buffer, num); + server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { + KVMessage kv_message; + kv_message.ParseFromString(message.data()); + EXPECT_EQ(2, kv_message.keys_size()); + server.SendMessage(conn, message); }); - server_->InitServer(); + server_->Init(); server_->Start(); }); http_server_thread_->detach(); @@ -47,21 +52,32 @@ class TestTcpServer : public UT::Common { server_->Stop(); } - TcpClient *client_; - TcpServer *server_; - const std::string test_message_ = "TCP_MESSAGE"; + std::unique_ptr client_; + std::unique_ptr server_; }; TEST_F(TestTcpServer, ServerSendMessage) { - client_ = new TcpClient("127.0.0.1", 9000); + client_ = std::make_unique("127.0.0.1", 9998); std::unique_ptr http_client_thread(nullptr); http_client_thread = std::make_unique([&]() { - client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) { - EXPECT_STREQ(std::string(reinterpret_cast(buffer), num).c_str(), "TCP_MESSAGE"); + client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { + KVMessage kv_message; + kv_message.ParseFromString(message.data()); + EXPECT_EQ(2, kv_message.keys_size()); }); - client_->InitTcpClient(); - client_->SendMessage(test_message_.c_str(), test_message_.size()); + client_->Init(); + + CommMessage comm_message; + KVMessage kv_message; + std::vector keys{1, 2}; + std::vector values{3, 4}; + *kv_message.mutable_keys() = {keys.begin(), keys.end()}; + *kv_message.mutable_values() = {values.begin(), values.end()}; + + comm_message.set_data(kv_message.SerializeAsString()); + client_->SendMessage(comm_message); + client_->Start(); }); http_client_thread->detach();