From 98bbb4ef7a88ae70292bbedfaf7994f0f30b7d60 Mon Sep 17 00:00:00 2001 From: anancds Date: Tue, 20 Oct 2020 15:30:00 +0800 Subject: [PATCH] added tcp base on libevent --- mindspore/ccsrc/ps/CMakeLists.txt | 4 + mindspore/ccsrc/ps/comm/comm_util.cc | 50 ++++ mindspore/ccsrc/ps/comm/comm_util.h | 49 ++++ .../ccsrc/ps/comm/http_message_handler.h | 2 +- mindspore/ccsrc/ps/comm/http_server.cc | 47 +--- mindspore/ccsrc/ps/comm/http_server.h | 15 +- mindspore/ccsrc/ps/comm/tcp_client.cc | 220 +++++++++++++++ mindspore/ccsrc/ps/comm/tcp_client.h | 77 ++++++ .../ccsrc/ps/comm/tcp_message_handler.cc | 36 +++ mindspore/ccsrc/ps/comm/tcp_message_handler.h | 47 ++++ mindspore/ccsrc/ps/comm/tcp_server.cc | 259 ++++++++++++++++++ mindspore/ccsrc/ps/comm/tcp_server.h | 107 ++++++++ tests/ut/cpp/ps/comm/http_server_test.cc | 58 +++- tests/ut/cpp/ps/comm/tcp_client_tests.cc | 46 ++++ tests/ut/cpp/ps/comm/tcp_server_tests.cc | 71 +++++ 15 files changed, 1026 insertions(+), 62 deletions(-) create mode 100644 mindspore/ccsrc/ps/comm/comm_util.cc create mode 100644 mindspore/ccsrc/ps/comm/comm_util.h create mode 100644 mindspore/ccsrc/ps/comm/tcp_client.cc create mode 100644 mindspore/ccsrc/ps/comm/tcp_client.h create mode 100644 mindspore/ccsrc/ps/comm/tcp_message_handler.cc create mode 100644 mindspore/ccsrc/ps/comm/tcp_message_handler.h create mode 100644 mindspore/ccsrc/ps/comm/tcp_server.cc create mode 100644 mindspore/ccsrc/ps/comm/tcp_server.h create mode 100644 tests/ut/cpp/ps/comm/tcp_client_tests.cc create mode 100644 tests/ut/cpp/ps/comm/tcp_server_tests.cc diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index 9bc3f31312..e8e412734f 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -7,6 +7,10 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "util.cc") list(REMOVE_ITEM _PS_SRC_FILES "comm/http_message_handler.cc") list(REMOVE_ITEM _PS_SRC_FILES "comm/http_server.cc") + list(REMOVE_ITEM _PS_SRC_FILES "comm/comm_util.cc") + list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_client.cc") + list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_message_handler.cc") + list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_server.cc") endif() set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) diff --git a/mindspore/ccsrc/ps/comm/comm_util.cc b/mindspore/ccsrc/ps/comm/comm_util.cc new file mode 100644 index 0000000000..1b3be35edc --- /dev/null +++ b/mindspore/ccsrc/ps/comm/comm_util.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/comm/comm_util.h" + +#include +#include +#include +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace comm { + +bool CommUtil::CheckIpWithRegex(const std::string &ip) { + std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); + std::smatch res; + if (regex_match(ip, res, pattern)) { + return true; + } + return false; +} + +void CommUtil::CheckIp(const std::string &ip) { + if (!CheckIpWithRegex(ip)) { + MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!"; + } + int64_t uAddr = inet_addr(ip.c_str()); + if (INADDR_NONE == uAddr) { + MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!"; + } +} +} // namespace comm +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/comm_util.h b/mindspore/ccsrc/ps/comm/comm_util.h new file mode 100644 index 0000000000..46455671e4 --- /dev/null +++ b/mindspore/ccsrc/ps/comm/comm_util.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ +#define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +namespace comm { + +class CommUtil { + public: + static bool CheckIpWithRegex(const std::string &ip); + static void CheckIp(const std::string &ip); +}; +} // namespace comm +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ diff --git a/mindspore/ccsrc/ps/comm/http_message_handler.h b/mindspore/ccsrc/ps/comm/http_message_handler.h index ad9bc6a139..c2b338700d 100644 --- a/mindspore/ccsrc/ps/comm/http_message_handler.h +++ b/mindspore/ccsrc/ps/comm/http_message_handler.h @@ -38,7 +38,7 @@ namespace mindspore { namespace ps { namespace comm { -typedef std::map> HttpHeaders; +using HttpHeaders = std::map>; class HttpMessageHandler { public: diff --git a/mindspore/ccsrc/ps/comm/http_server.cc b/mindspore/ccsrc/ps/comm/http_server.cc index f99ec69e4a..35ed6edd06 100644 --- a/mindspore/ccsrc/ps/comm/http_server.cc +++ b/mindspore/ccsrc/ps/comm/http_server.cc @@ -16,6 +16,7 @@ #include "ps/comm/http_server.h" #include "ps/comm/http_message_handler.h" +#include "ps/comm/comm_util.h" #ifdef WIN32 #include @@ -41,28 +42,10 @@ namespace mindspore { namespace ps { namespace comm { -HttpServer::~HttpServer() { - if (event_http_) { - evhttp_free(event_http_); - event_http_ = nullptr; - } - if (event_base_) { - event_base_free(event_base_); - event_base_ = nullptr; - } -} +HttpServer::~HttpServer() { Stop(); } bool HttpServer::InitServer() { - if (!CheckIp(server_address_)) { - MS_LOG(EXCEPTION) << "Server address" << server_address_ << " illegal!"; - } - int64_t uAddr = inet_addr(server_address_.c_str()); - if (INADDR_NONE == uAddr) { - MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!"; - } - if (server_port_ <= 0) { - MS_LOG(EXCEPTION) << "Server port:" << server_port_ << " illegal!"; - } + CommUtil::CheckIp(server_address_); event_base_ = event_base_new(); MS_EXCEPTION_IF_NULL(event_base_); @@ -76,15 +59,6 @@ bool HttpServer::InitServer() { return true; } -bool HttpServer::CheckIp(const std::string &ip) { - std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); - std::smatch res; - if (regex_match(ip, res, pattern)) { - return true; - } - return false; -} - void HttpServer::SetTimeOut(int seconds) { MS_EXCEPTION_IF_NULL(event_http_); if (seconds < 0) { @@ -93,7 +67,7 @@ void HttpServer::SetTimeOut(int seconds) { evhttp_set_timeout(event_http_, seconds); } -void HttpServer::SetAllowedMethod(HttpMethodsSet methods) { +void HttpServer::SetAllowedMethod(u_int16_t methods) { MS_EXCEPTION_IF_NULL(event_http_); evhttp_set_allowed_methods(event_http_, methods); } @@ -114,12 +88,11 @@ void HttpServer::SetMaxBodySize(size_t num) { evhttp_set_max_body_size(event_http_, num); } -bool HttpServer::RegisterRoute(const std::string &url, handle_t *function) { +bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *function) { if ((!is_init_) && (!InitServer())) { MS_LOG(EXCEPTION) << "Init http server failed!"; } - HandlerFunc func = function; - if (!func) { + if (!function) { return false; } @@ -128,15 +101,13 @@ bool HttpServer::RegisterRoute(const std::string &url, handle_t *function) { MS_EXCEPTION_IF_NULL(arg); HttpMessageHandler httpReq(req); httpReq.InitHttpMessage(); - handle_t *f = reinterpret_cast(arg); - f(&httpReq); + OnRequestReceive *func = reinterpret_cast(arg); + (*func)(&httpReq); }; - handle_t **pph = func.target(); - MS_EXCEPTION_IF_NULL(pph); MS_EXCEPTION_IF_NULL(event_http_); // O SUCCESS,-1 ALREADY_EXIST,-2 FAILURE - int ret = evhttp_set_cb(event_http_, url.c_str(), TransFunc, reinterpret_cast(*pph)); + int ret = evhttp_set_cb(event_http_, url.c_str(), TransFunc, reinterpret_cast(function)); if (ret == 0) { MS_LOG(INFO) << "Ev http register handle of:" << url.c_str() << " success."; } else if (ret == -1) { diff --git a/mindspore/ccsrc/ps/comm/http_server.h b/mindspore/ccsrc/ps/comm/http_server.h index 79d1387e9e..65fa29597e 100644 --- a/mindspore/ccsrc/ps/comm/http_server.h +++ b/mindspore/ccsrc/ps/comm/http_server.h @@ -48,26 +48,21 @@ typedef enum eHttpMethod { HM_PATCH = 1 << 8 } HttpMethod; -typedef u_int16_t HttpMethodsSet; - -typedef void(handle_t)(HttpMessageHandler *); - class HttpServer { public: // Server address only support IPV4 now, and should be in format of "x.x.x.x" - explicit HttpServer(const std::string &address, std::int16_t port) + explicit HttpServer(const std::string &address, std::uint16_t port) : server_address_(address), server_port_(port), event_base_(nullptr), event_http_(nullptr), is_init_(false) {} ~HttpServer(); - typedef std::function HandlerFunc; + using OnRequestReceive = std::function; bool InitServer(); - static bool CheckIp(const std::string &ip); void SetTimeOut(int seconds = 5); // Default allowed methods: GET, POST, HEAD, PUT, DELETE - void SetAllowedMethod(HttpMethodsSet methods); + void SetAllowedMethod(u_int16_t methods); // Default to ((((unsigned long long)0xffffffffUL) << 32) | 0xffffffffUL) void SetMaxHeaderSize(std::size_t num); @@ -76,7 +71,7 @@ class HttpServer { void SetMaxBodySize(std::size_t num); // Return: true if success, false if failed, check log to find failure reason - bool RegisterRoute(const std::string &url, handle_t *func); + bool RegisterRoute(const std::string &url, OnRequestReceive *func); bool UnRegisterRoute(const std::string &url); bool Start(); @@ -84,7 +79,7 @@ class HttpServer { private: std::string server_address_; - std::int16_t server_port_; + std::uint16_t server_port_; struct event_base *event_base_; struct evhttp *event_http_; bool is_init_; diff --git a/mindspore/ccsrc/ps/comm/tcp_client.cc b/mindspore/ccsrc/ps/comm/tcp_client.cc new file mode 100644 index 0000000000..3a6e184b6b --- /dev/null +++ b/mindspore/ccsrc/ps/comm/tcp_client.cc @@ -0,0 +1,220 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/comm/tcp_client.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ps/comm/comm_util.h" + +namespace mindspore { +namespace ps { +namespace comm { + +TcpClient::TcpClient(std::string address, std::uint16_t port) + : event_base_(nullptr), + event_timeout_(nullptr), + buffer_event_(nullptr), + server_address_(std::move(address)), + server_port_(port) { + message_handler_.SetCallback([this](const void *buf, size_t num) { + if (buf == nullptr) { + if (disconnected_callback_) disconnected_callback_(*this, 200); + Stop(); + } + if (message_callback_) message_callback_(*this, buf, num); + }); +} + +TcpClient::~TcpClient() { Stop(); } + +std::string TcpClient::GetServerAddress() const { return server_address_; } + +void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, + const OnTimeout &timeout) { + connected_callback_ = conn; + disconnected_callback_ = disconn; + read_callback_ = read; + timeout_callback_ = timeout; +} + +void TcpClient::InitTcpClient() { + if (buffer_event_) { + return; + } + CommUtil::CheckIp(server_address_); + + event_base_ = event_base_new(); + MS_EXCEPTION_IF_NULL(event_base_); + + sockaddr_in sin{}; + if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { + MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!"; + } + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); + sin.sin_port = htons(server_port_); + + buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE); + MS_EXCEPTION_IF_NULL(buffer_event_); + + bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this); + if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) { + MS_LOG(EXCEPTION) << "buffer event enable read and write failed!"; + } + + int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast(&sin), sizeof(sin)); + if (result_code < 0) { + MS_LOG(EXCEPTION) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!"; + } +} + +void TcpClient::StartWithDelay(int seconds) { + if (buffer_event_) { + return; + } + + event_base_ = event_base_new(); + + timeval timeout_value{}; + timeout_value.tv_sec = seconds; + timeout_value.tv_usec = 0; + + event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this); + if (evtimer_add(event_timeout_, &timeout_value) == -1) { + MS_LOG(EXCEPTION) << "event timeout failed!"; + } +} + +void TcpClient::Stop() { + if (buffer_event_) { + bufferevent_free(buffer_event_); + buffer_event_ = nullptr; + } + + if (event_timeout_) { + event_free(event_timeout_); + event_timeout_ = nullptr; + } + + if (event_base_) { + event_base_free(event_base_); + event_base_ = nullptr; + } +} + +void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) { + const int one = 1; + int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int)); + if (ret < 0) { + MS_LOG(EXCEPTION) << "Set socket no delay failed!"; + } +} + +void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) { + MS_EXCEPTION_IF_NULL(arg); + auto tcp_client = reinterpret_cast(arg); + tcp_client->InitTcpClient(); +} + +void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { + MS_EXCEPTION_IF_NULL(bev); + MS_EXCEPTION_IF_NULL(ctx); + auto tcp_client = reinterpret_cast(ctx); + struct evbuffer *input = bufferevent_get_input(const_cast(bev)); + MS_EXCEPTION_IF_NULL(input); + + char read_buffer[4096]; + int read = 0; + + while ((read = EVBUFFER_LENGTH(input)) > 0) { + if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) { + MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; + } + tcp_client->OnReadHandler(read_buffer, read); + } +} + +void TcpClient::OnReadHandler(const void *buf, size_t num) { + MS_EXCEPTION_IF_NULL(buf); + if (read_callback_) { + read_callback_(*this, buf, num); + } + message_handler_.ReceiveMessage(buf, num); +} + +void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { + MS_EXCEPTION_IF_NULL(bev); + MS_EXCEPTION_IF_NULL(ptr); + auto tcp_client = reinterpret_cast(ptr); + if (events & BEV_EVENT_CONNECTED) { + // Connected + if (tcp_client->connected_callback_) { + tcp_client->connected_callback_(*tcp_client); + } + evutil_socket_t fd = bufferevent_getfd(const_cast(bev)); + SetTcpNoDelay(fd); + MS_LOG(INFO) << "Client connected!"; + } else if (events & BEV_EVENT_ERROR) { + MS_LOG(ERROR) << "Client connected error!"; + if (tcp_client->disconnected_callback_) { + tcp_client->disconnected_callback_(*tcp_client, errno); + } + } else if (events & BEV_EVENT_EOF) { + MS_LOG(ERROR) << "Client connected end of file"; + if (tcp_client->disconnected_callback_) { + tcp_client->disconnected_callback_(*tcp_client, 0); + } + } +} + +void TcpClient::Start() { + MS_EXCEPTION_IF_NULL(event_base_); + int ret = event_base_dispatch(event_base_); + if (ret == 0) { + MS_LOG(INFO) << "Event base dispatch success!"; + } else if (ret == 1) { + MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; + } else if (ret == -1) { + MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; + } else { + MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; + } +} + +void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; } + +void TcpClient::SendMessage(const void *buf, size_t num) const { + MS_EXCEPTION_IF_NULL(buffer_event_); + if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) { + MS_LOG(EXCEPTION) << "event buffer add failed!"; + } +} +} // namespace comm +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_client.h b/mindspore/ccsrc/ps/comm/tcp_client.h new file mode 100644 index 0000000000..49d7478dab --- /dev/null +++ b/mindspore/ccsrc/ps/comm/tcp_client.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ +#define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ + +#include "ps/comm/tcp_message_handler.h" + +#include +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace comm { + +class TcpClient { + public: + using OnMessage = std::function; + using OnConnected = std::function; + using OnDisconnected = std::function; + using OnRead = std::function; + using OnTimeout = std::function; + + explicit TcpClient(std::string address, std::uint16_t port); + virtual ~TcpClient(); + + std::string GetServerAddress() const; + void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, + const OnTimeout &timeout); + void InitTcpClient(); + void StartWithDelay(int seconds); + void Stop(); + void ReceiveMessage(const OnMessage &cb); + void SendMessage(const void *buf, size_t num) const; + void Start(); + + protected: + static void SetTcpNoDelay(const evutil_socket_t &fd); + static void TimeoutCallback(evutil_socket_t fd, std::int16_t what, void *arg); + static void ReadCallback(struct bufferevent *bev, void *ctx); + static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); + virtual void OnReadHandler(const void *buf, size_t num); + + private: + TcpMessageHandler message_handler_; + OnMessage message_callback_; + OnConnected connected_callback_; + OnDisconnected disconnected_callback_; + OnRead read_callback_; + OnTimeout timeout_callback_; + + event_base *event_base_; + event *event_timeout_; + bufferevent *buffer_event_; + + std::string server_address_; + std::uint16_t server_port_; +}; +} // namespace comm +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ diff --git a/mindspore/ccsrc/ps/comm/tcp_message_handler.cc b/mindspore/ccsrc/ps/comm/tcp_message_handler.cc new file mode 100644 index 0000000000..5755802346 --- /dev/null +++ b/mindspore/ccsrc/ps/comm/tcp_message_handler.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/comm/tcp_message_handler.h" +#include +#include + +namespace mindspore { +namespace ps { +namespace comm { + +void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); } + +void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { + MS_EXCEPTION_IF_NULL(buffer); + + if (message_callback_) { + message_callback_(buffer, num); + } +} +} // namespace comm +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_message_handler.h b/mindspore/ccsrc/ps/comm/tcp_message_handler.h new file mode 100644 index 0000000000..339e25a06a --- /dev/null +++ b/mindspore/ccsrc/ps/comm/tcp_message_handler.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ +#define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ + +#include +#include +#include + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +namespace comm { + +using messageReceive = std::function; + +class TcpMessageHandler { + public: + TcpMessageHandler() = default; + virtual ~TcpMessageHandler() = default; + + void SetCallback(messageReceive cb); + void ReceiveMessage(const void *buffer, size_t num); + + private: + messageReceive message_callback_; +}; +} // namespace comm +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ diff --git a/mindspore/ccsrc/ps/comm/tcp_server.cc b/mindspore/ccsrc/ps/comm/tcp_server.cc new file mode 100644 index 0000000000..6ef89c6921 --- /dev/null +++ b/mindspore/ccsrc/ps/comm/tcp_server.cc @@ -0,0 +1,259 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/comm/tcp_server.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ps/comm/comm_util.h" + +namespace mindspore { +namespace ps { +namespace comm { + +void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) { + MS_EXCEPTION_IF_NULL(bev); + MS_EXCEPTION_IF_NULL(server); + buffer_event_ = const_cast(bev); + fd_ = fd; + server_ = const_cast(server); + + tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) { + OnServerReceiveMessage message_callback = server->GetServerReceiveMessage(); + if (message_callback) message_callback(*server, *this, buf, num); + }); +} + +void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); } + +void TcpConnection::SendMessage(const void *buffer, size_t num) const { + if (bufferevent_write(buffer_event_, buffer, num) == -1) { + MS_LOG(ERROR) << "Write message to buffer event failed!"; + } +} + +TcpServer *TcpConnection::GetServer() const { return server_; } + +evutil_socket_t TcpConnection::GetFd() const { return fd_; } + +TcpServer::TcpServer(std::string address, std::uint16_t port) + : base_(nullptr), + signal_event_(nullptr), + listener_(nullptr), + server_address_(std::move(address)), + server_port_(port) {} + +TcpServer::~TcpServer() { Stop(); } + +void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, + const OnAccepted &client_accept) { + this->client_connection_ = client_conn; + this->client_disconnection_ = client_disconn; + this->client_accept_ = client_accept; +} + +void TcpServer::InitServer() { + base_ = event_base_new(); + MS_EXCEPTION_IF_NULL(base_); + CommUtil::CheckIp(server_address_); + + struct sockaddr_in sin {}; + if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { + MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!"; + } + sin.sin_family = AF_INET; + sin.sin_port = htons(server_port_); + sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); + + listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast(this), + LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1, + reinterpret_cast(&sin), sizeof(sin)); + + MS_EXCEPTION_IF_NULL(listener_); + + signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast(this)); + MS_EXCEPTION_IF_NULL(signal_event_); + if (event_add(signal_event_, nullptr) < 0) { + MS_LOG(EXCEPTION) << "Cannot create signal event."; + } +} + +void TcpServer::Start() { + std::unique_lock l(connection_mutex_); + MS_EXCEPTION_IF_NULL(base_); + int ret = event_base_dispatch(base_); + if (ret == 0) { + MS_LOG(INFO) << "Event base dispatch success!"; + } else if (ret == 1) { + MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; + } else if (ret == -1) { + MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; + } else { + MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; + } +} + +void TcpServer::Stop() { + if (signal_event_ != nullptr) { + event_free(signal_event_); + signal_event_ = nullptr; + } + + if (listener_ != nullptr) { + evconnlistener_free(listener_); + listener_ = nullptr; + } + + if (base_ != nullptr) { + event_base_free(base_); + base_ = nullptr; + } +} + +void TcpServer::SendToAllClients(const char *data, size_t len) { + MS_EXCEPTION_IF_NULL(data); + std::unique_lock lock(connection_mutex_); + for (auto it = connections_.begin(); it != connections_.end(); ++it) { + it->second->SendMessage(data, len); + } +} + +void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) { + MS_EXCEPTION_IF_NULL(connection); + std::unique_lock lock(connection_mutex_); + connections_.insert(std::make_pair(fd, connection)); +} + +void TcpServer::RemoveConnection(const evutil_socket_t &fd) { + std::unique_lock lock(connection_mutex_); + connections_.erase(fd); +} + +void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *, int, void *data) { + auto server = reinterpret_cast(data); + auto base = reinterpret_cast(server->base_); + MS_EXCEPTION_IF_NULL(server); + MS_EXCEPTION_IF_NULL(base); + + struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); + if (!bev) { + MS_LOG(ERROR) << "Error constructing buffer event!"; + event_base_loopbreak(base); + return; + } + + TcpConnection *conn = server->onCreateConnection(); + MS_EXCEPTION_IF_NULL(conn); + + conn->InitConnection(fd, bev, server); + server->AddConnection(fd, conn); + bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast(conn)); + if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { + MS_LOG(EXCEPTION) << "buffer event enable read and write failed!"; + } +} + +TcpConnection *TcpServer::onCreateConnection() { + TcpConnection *conn = nullptr; + if (client_accept_) + conn = const_cast(client_accept_(this)); + else + conn = new TcpConnection(); + + return conn; +} + +OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; } + +void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) { + auto server = reinterpret_cast(data); + MS_EXCEPTION_IF_NULL(server); + struct event_base *base = server->base_; + struct timeval delay = {0, 0}; + MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds."; + if (event_base_loopexit(base, &delay) == -1) { + MS_LOG(EXCEPTION) << "event base loop exit failed."; + } +} + +void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { + MS_EXCEPTION_IF_NULL(bev); + MS_EXCEPTION_IF_NULL(connection); + + auto conn = static_cast(connection); + struct evbuffer *buf = bufferevent_get_input(bev); + char read_buffer[4096]; + auto read = 0; + while ((read = EVBUFFER_LENGTH(buf)) > 0) { + if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) { + MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; + } + conn->OnReadHandler(read_buffer, static_cast(read)); + } +} + +void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) { + MS_EXCEPTION_IF_NULL(bev); + MS_EXCEPTION_IF_NULL(data); + auto conn = reinterpret_cast(data); + TcpServer *srv = conn->GetServer(); + + if (events & BEV_EVENT_EOF) { + // Notify about disconnection + if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); + // Free connection structures + srv->RemoveConnection(conn->GetFd()); + bufferevent_free(bev); + } else if (events & BEV_EVENT_ERROR) { + // Free connection structures + srv->RemoveConnection(conn->GetFd()); + bufferevent_free(bev); + + // Notify about disconnection + if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); + } else { + MS_LOG(ERROR) << "unhandled event!"; + } +} + +void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; } + +void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) { + MS_EXCEPTION_IF_NULL(data); + auto mc = const_cast(conn); + mc.SendMessage(data, num); +} + +void TcpServer::SendMessage(const void *data, size_t num) { + MS_EXCEPTION_IF_NULL(data); + std::unique_lock lock(connection_mutex_); + + for (auto it = connections_.begin(); it != connections_.end(); ++it) { + SendMessage(*it->second, data, num); + } +} +} // namespace comm +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_server.h b/mindspore/ccsrc/ps/comm/tcp_server.h new file mode 100644 index 0000000000..ccb9ef5e8a --- /dev/null +++ b/mindspore/ccsrc/ps/comm/tcp_server.h @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ +#define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils/log_adapter.h" +#include "ps/comm/tcp_message_handler.h" + +namespace mindspore { +namespace ps { +namespace comm { + +class TcpServer; +class TcpConnection { + public: + TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {} + virtual ~TcpConnection() = default; + + virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server); + void SendMessage(const void *buffer, size_t num) const; + virtual void OnReadHandler(const void *buffer, size_t numBytes); + TcpServer *GetServer() const; + evutil_socket_t GetFd() const; + + protected: + TcpMessageHandler tcp_message_handler_; + struct bufferevent *buffer_event_; + evutil_socket_t fd_; + TcpServer *server_; +}; + +using OnServerReceiveMessage = + std::function; + +class TcpServer { + public: + using OnConnected = std::function; + using OnDisconnected = std::function; + using OnAccepted = std::function; + + explicit TcpServer(std::string address, std::uint16_t port); + virtual ~TcpServer(); + + void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, + const OnAccepted &client_accept); + void InitServer(); + void Start(); + void Stop(); + void SendToAllClients(const char *data, size_t len); + void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); + void RemoveConnection(const evutil_socket_t &fd); + void ReceiveMessage(const OnServerReceiveMessage &cb); + static void SendMessage(const TcpConnection &conn, const void *data, size_t num); + void SendMessage(const void *data, size_t num); + OnServerReceiveMessage GetServerReceiveMessage() const; + + protected: + static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, + int socklen, void *server); + static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server); + static void ReadCallback(struct bufferevent *, void *connection); + static void EventCallback(struct bufferevent *, std::int16_t events, void *server); + virtual TcpConnection *onCreateConnection(); + + private: + struct event_base *base_; + struct event *signal_event_; + struct evconnlistener *listener_; + std::string server_address_; + std::uint16_t server_port_; + + std::map connections_; + OnConnected client_connection_; + OnDisconnected client_disconnection_; + OnAccepted client_accept_; + std::recursive_mutex connection_mutex_; + OnServerReceiveMessage message_callback_; +}; +} // namespace comm +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ diff --git a/tests/ut/cpp/ps/comm/http_server_test.cc b/tests/ut/cpp/ps/comm/http_server_test.cc index 514398cf82..1ecb3d0ea6 100644 --- a/tests/ut/cpp/ps/comm/http_server_test.cc +++ b/tests/ut/cpp/ps/comm/http_server_test.cc @@ -31,7 +31,7 @@ namespace comm { class TestHttpServer : public UT::Common { public: - TestHttpServer() {} + TestHttpServer() = default; static void testGetHandler(HttpMessageHandler *resp) { std::string host = resp->GetRequestHost(); @@ -58,16 +58,44 @@ class TestHttpServer : public UT::Common { void SetUp() override { server_ = new HttpServer("0.0.0.0", 9999); - server_->RegisterRoute("/httpget", [](HttpMessageHandler *resp) { - EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1"); - EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1"); - EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1"); - EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget"); - resp->QuickResponse(200, "get request success!\n"); - }); - server_->RegisterRoute("/handler", TestHttpServer::testGetHandler); + std::function http_get_func = std::bind( + [](HttpMessageHandler *resp) { + EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1"); + EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1"); + EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1"); + EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget"); + resp->QuickResponse(200, "get request success!\n"); + }, + std::placeholders::_1); + + std::function http_handler_func = std::bind( + [](HttpMessageHandler *resp) { + std::string host = resp->GetRequestHost(); + EXPECT_STREQ(host.c_str(), "127.0.0.1"); + + std::string path_param = resp->GetPathParam("key1"); + std::string header_param = resp->GetHeadParam("headerKey"); + std::string post_param = resp->GetPostParam("postKey"); + std::string post_message = resp->GetPostMsg(); + EXPECT_STREQ(path_param.c_str(), "value1"); + EXPECT_STREQ(header_param.c_str(), "headerValue"); + EXPECT_STREQ(post_param.c_str(), "postValue"); + EXPECT_STREQ(post_message.c_str(), "postKey=postValue"); + + const std::string rKey("headKey"); + const std::string rVal("headValue"); + const std::string rBody("post request success!\n"); + resp->AddRespHeadParam(rKey, rVal); + resp->AddRespString(rBody); + + resp->SetRespCode(200); + resp->SendResponse(); + }, + std::placeholders::_1); + server_->RegisterRoute("/httpget", &http_get_func); + server_->RegisterRoute("/handler", &http_handler_func); std::unique_ptr http_server_thread_(nullptr); - http_server_thread_.reset(new std::thread([&]() { server_->Start(); })); + http_server_thread_ = std::make_unique([&]() { server_->Start(); }); http_server_thread_->detach(); } @@ -110,14 +138,18 @@ TEST_F(TestHttpServer, messageHandler) { pclose(file); } -TEST_F(TestHttpServer, portException) { +TEST_F(TestHttpServer, portErrorNoException) { HttpServer *server_exception = new HttpServer("0.0.0.0", -1); - ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception); + std::function http_handler_func = + std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); + EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func)); } TEST_F(TestHttpServer, addressException) { HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998); - ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception); + std::function http_handler_func = + std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); + ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); } } // namespace comm diff --git a/tests/ut/cpp/ps/comm/tcp_client_tests.cc b/tests/ut/cpp/ps/comm/tcp_client_tests.cc new file mode 100644 index 0000000000..424e7cc286 --- /dev/null +++ b/tests/ut/cpp/ps/comm/tcp_client_tests.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common_test.h" +#include "ps/comm/tcp_client.h" + +namespace mindspore { +namespace ps { +namespace comm { +class TestTcpClient : public UT::Common { + public: + TestTcpClient() = default; +}; + +TEST_F(TestTcpClient, InitClientIPError) { + auto client = new TcpClient("127.0.0.13543", 9000); + client->ReceiveMessage( + [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); + + ASSERT_THROW(client->InitTcpClient(), std::exception); +} + +TEST_F(TestTcpClient, InitClientPortErrorNoException) { + auto client = new TcpClient("127.0.0.1", -1); + client->ReceiveMessage( + [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); + + EXPECT_NO_THROW(client->InitTcpClient()); +} + +} // namespace comm +} // namespace ps +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/ps/comm/tcp_server_tests.cc b/tests/ut/cpp/ps/comm/tcp_server_tests.cc new file mode 100644 index 0000000000..c3c8ac96d8 --- /dev/null +++ b/tests/ut/cpp/ps/comm/tcp_server_tests.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/comm/tcp_client.h" +#include "ps/comm/tcp_server.h" +#include "common/common_test.h" + +#include + +namespace mindspore { +namespace ps { +namespace comm { +class TestTcpServer : public UT::Common { + public: + TestTcpServer() = default; + void SetUp() override { + server_ = new TcpServer("127.0.0.1", 9000); + std::unique_ptr http_server_thread_(nullptr); + http_server_thread_ = std::make_unique([&]() { + server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) { + EXPECT_STREQ(std::string(reinterpret_cast(buffer), num).c_str(), "TCP_MESSAGE"); + server.SendMessage(conn, buffer, num); + }); + server_->InitServer(); + server_->Start(); + }); + http_server_thread_->detach(); + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + } + void TearDown() override { + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + client_->Stop(); + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + server_->Stop(); + } + + TcpClient *client_; + TcpServer *server_; + const std::string test_message_ = "TCP_MESSAGE"; +}; + +TEST_F(TestTcpServer, ServerSendMessage) { + client_ = new TcpClient("127.0.0.1", 9000); + std::unique_ptr http_client_thread(nullptr); + http_client_thread = std::make_unique([&]() { + client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) { + EXPECT_STREQ(std::string(reinterpret_cast(buffer), num).c_str(), "TCP_MESSAGE"); + }); + + client_->InitTcpClient(); + client_->SendMessage(test_message_.c_str(), test_message_.size()); + client_->Start(); + }); + http_client_thread->detach(); +} +} // namespace comm +} // namespace ps +} // namespace mindspore \ No newline at end of file