!7518 Added tcp server based on libevent
Merge pull request !7518 from anancds/tcp-serverpull/7518/MERGE
commit
80725441b6
@ -0,0 +1,50 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/comm/comm_util.h"
|
||||||
|
|
||||||
|
#include <arpa/inet.h>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <functional>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace comm {
|
||||||
|
|
||||||
|
bool CommUtil::CheckIpWithRegex(const std::string &ip) {
|
||||||
|
std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)");
|
||||||
|
std::smatch res;
|
||||||
|
if (regex_match(ip, res, pattern)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommUtil::CheckIp(const std::string &ip) {
|
||||||
|
if (!CheckIpWithRegex(ip)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!";
|
||||||
|
}
|
||||||
|
int64_t uAddr = inet_addr(ip.c_str());
|
||||||
|
if (INADDR_NONE == uAddr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace comm
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,49 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
|
||||||
|
|
||||||
|
#include <event2/buffer.h>
|
||||||
|
#include <event2/event.h>
|
||||||
|
#include <event2/http.h>
|
||||||
|
#include <event2/keyvalq_struct.h>
|
||||||
|
#include <event2/listener.h>
|
||||||
|
#include <event2/util.h>
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace comm {
|
||||||
|
|
||||||
|
class CommUtil {
|
||||||
|
public:
|
||||||
|
static bool CheckIpWithRegex(const std::string &ip);
|
||||||
|
static void CheckIp(const std::string &ip);
|
||||||
|
};
|
||||||
|
} // namespace comm
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
|
@ -0,0 +1,220 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/comm/tcp_client.h"
|
||||||
|
|
||||||
|
#include <arpa/inet.h>
|
||||||
|
#include <event2/buffer.h>
|
||||||
|
#include <event2/bufferevent.h>
|
||||||
|
#include <event2/buffer_compat.h>
|
||||||
|
#include <event2/event.h>
|
||||||
|
#include <netinet/in.h>
|
||||||
|
#include <netinet/tcp.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "ps/comm/comm_util.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace comm {
|
||||||
|
|
||||||
|
TcpClient::TcpClient(std::string address, std::uint16_t port)
|
||||||
|
: event_base_(nullptr),
|
||||||
|
event_timeout_(nullptr),
|
||||||
|
buffer_event_(nullptr),
|
||||||
|
server_address_(std::move(address)),
|
||||||
|
server_port_(port) {
|
||||||
|
message_handler_.SetCallback([this](const void *buf, size_t num) {
|
||||||
|
if (buf == nullptr) {
|
||||||
|
if (disconnected_callback_) disconnected_callback_(*this, 200);
|
||||||
|
Stop();
|
||||||
|
}
|
||||||
|
if (message_callback_) message_callback_(*this, buf, num);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TcpClient::~TcpClient() { Stop(); }
|
||||||
|
|
||||||
|
std::string TcpClient::GetServerAddress() const { return server_address_; }
|
||||||
|
|
||||||
|
void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
|
||||||
|
const OnTimeout &timeout) {
|
||||||
|
connected_callback_ = conn;
|
||||||
|
disconnected_callback_ = disconn;
|
||||||
|
read_callback_ = read;
|
||||||
|
timeout_callback_ = timeout;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::InitTcpClient() {
|
||||||
|
if (buffer_event_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
CommUtil::CheckIp(server_address_);
|
||||||
|
|
||||||
|
event_base_ = event_base_new();
|
||||||
|
MS_EXCEPTION_IF_NULL(event_base_);
|
||||||
|
|
||||||
|
sockaddr_in sin{};
|
||||||
|
if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
|
||||||
|
MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
|
||||||
|
}
|
||||||
|
sin.sin_family = AF_INET;
|
||||||
|
sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
|
||||||
|
sin.sin_port = htons(server_port_);
|
||||||
|
|
||||||
|
buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE);
|
||||||
|
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||||
|
|
||||||
|
bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
|
||||||
|
if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) {
|
||||||
|
MS_LOG(EXCEPTION) << "buffer event enable read and write failed!";
|
||||||
|
}
|
||||||
|
|
||||||
|
int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
|
||||||
|
if (result_code < 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::StartWithDelay(int seconds) {
|
||||||
|
if (buffer_event_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
event_base_ = event_base_new();
|
||||||
|
|
||||||
|
timeval timeout_value{};
|
||||||
|
timeout_value.tv_sec = seconds;
|
||||||
|
timeout_value.tv_usec = 0;
|
||||||
|
|
||||||
|
event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
|
||||||
|
if (evtimer_add(event_timeout_, &timeout_value) == -1) {
|
||||||
|
MS_LOG(EXCEPTION) << "event timeout failed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::Stop() {
|
||||||
|
if (buffer_event_) {
|
||||||
|
bufferevent_free(buffer_event_);
|
||||||
|
buffer_event_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (event_timeout_) {
|
||||||
|
event_free(event_timeout_);
|
||||||
|
event_timeout_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (event_base_) {
|
||||||
|
event_base_free(event_base_);
|
||||||
|
event_base_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
|
||||||
|
const int one = 1;
|
||||||
|
int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
|
||||||
|
if (ret < 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Set socket no delay failed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
|
||||||
|
MS_EXCEPTION_IF_NULL(arg);
|
||||||
|
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
|
||||||
|
tcp_client->InitTcpClient();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
|
||||||
|
MS_EXCEPTION_IF_NULL(bev);
|
||||||
|
MS_EXCEPTION_IF_NULL(ctx);
|
||||||
|
auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
|
||||||
|
struct evbuffer *input = bufferevent_get_input(const_cast<struct bufferevent *>(bev));
|
||||||
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
|
||||||
|
char read_buffer[4096];
|
||||||
|
int read = 0;
|
||||||
|
|
||||||
|
while ((read = EVBUFFER_LENGTH(input)) > 0) {
|
||||||
|
if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
|
||||||
|
}
|
||||||
|
tcp_client->OnReadHandler(read_buffer, read);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::OnReadHandler(const void *buf, size_t num) {
|
||||||
|
MS_EXCEPTION_IF_NULL(buf);
|
||||||
|
if (read_callback_) {
|
||||||
|
read_callback_(*this, buf, num);
|
||||||
|
}
|
||||||
|
message_handler_.ReceiveMessage(buf, num);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
|
||||||
|
MS_EXCEPTION_IF_NULL(bev);
|
||||||
|
MS_EXCEPTION_IF_NULL(ptr);
|
||||||
|
auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
|
||||||
|
if (events & BEV_EVENT_CONNECTED) {
|
||||||
|
// Connected
|
||||||
|
if (tcp_client->connected_callback_) {
|
||||||
|
tcp_client->connected_callback_(*tcp_client);
|
||||||
|
}
|
||||||
|
evutil_socket_t fd = bufferevent_getfd(const_cast<struct bufferevent *>(bev));
|
||||||
|
SetTcpNoDelay(fd);
|
||||||
|
MS_LOG(INFO) << "Client connected!";
|
||||||
|
} else if (events & BEV_EVENT_ERROR) {
|
||||||
|
MS_LOG(ERROR) << "Client connected error!";
|
||||||
|
if (tcp_client->disconnected_callback_) {
|
||||||
|
tcp_client->disconnected_callback_(*tcp_client, errno);
|
||||||
|
}
|
||||||
|
} else if (events & BEV_EVENT_EOF) {
|
||||||
|
MS_LOG(ERROR) << "Client connected end of file";
|
||||||
|
if (tcp_client->disconnected_callback_) {
|
||||||
|
tcp_client->disconnected_callback_(*tcp_client, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::Start() {
|
||||||
|
MS_EXCEPTION_IF_NULL(event_base_);
|
||||||
|
int ret = event_base_dispatch(event_base_);
|
||||||
|
if (ret == 0) {
|
||||||
|
MS_LOG(INFO) << "Event base dispatch success!";
|
||||||
|
} else if (ret == 1) {
|
||||||
|
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
|
||||||
|
} else if (ret == -1) {
|
||||||
|
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; }
|
||||||
|
|
||||||
|
void TcpClient::SendMessage(const void *buf, size_t num) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||||
|
if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) {
|
||||||
|
MS_LOG(EXCEPTION) << "event buffer add failed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace comm
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,77 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
|
||||||
|
|
||||||
|
#include "ps/comm/tcp_message_handler.h"
|
||||||
|
|
||||||
|
#include <event2/event.h>
|
||||||
|
#include <event2/bufferevent.h>
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace comm {
|
||||||
|
|
||||||
|
class TcpClient {
|
||||||
|
public:
|
||||||
|
using OnMessage = std::function<void(const TcpClient &, const void *, size_t)>;
|
||||||
|
using OnConnected = std::function<void(const TcpClient &)>;
|
||||||
|
using OnDisconnected = std::function<void(const TcpClient &, int)>;
|
||||||
|
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
|
||||||
|
using OnTimeout = std::function<void(const TcpClient &)>;
|
||||||
|
|
||||||
|
explicit TcpClient(std::string address, std::uint16_t port);
|
||||||
|
virtual ~TcpClient();
|
||||||
|
|
||||||
|
std::string GetServerAddress() const;
|
||||||
|
void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
|
||||||
|
const OnTimeout &timeout);
|
||||||
|
void InitTcpClient();
|
||||||
|
void StartWithDelay(int seconds);
|
||||||
|
void Stop();
|
||||||
|
void ReceiveMessage(const OnMessage &cb);
|
||||||
|
void SendMessage(const void *buf, size_t num) const;
|
||||||
|
void Start();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
static void SetTcpNoDelay(const evutil_socket_t &fd);
|
||||||
|
static void TimeoutCallback(evutil_socket_t fd, std::int16_t what, void *arg);
|
||||||
|
static void ReadCallback(struct bufferevent *bev, void *ctx);
|
||||||
|
static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
|
||||||
|
virtual void OnReadHandler(const void *buf, size_t num);
|
||||||
|
|
||||||
|
private:
|
||||||
|
TcpMessageHandler message_handler_;
|
||||||
|
OnMessage message_callback_;
|
||||||
|
OnConnected connected_callback_;
|
||||||
|
OnDisconnected disconnected_callback_;
|
||||||
|
OnRead read_callback_;
|
||||||
|
OnTimeout timeout_callback_;
|
||||||
|
|
||||||
|
event_base *event_base_;
|
||||||
|
event *event_timeout_;
|
||||||
|
bufferevent *buffer_event_;
|
||||||
|
|
||||||
|
std::string server_address_;
|
||||||
|
std::uint16_t server_port_;
|
||||||
|
};
|
||||||
|
} // namespace comm
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
|
@ -0,0 +1,36 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/comm/tcp_message_handler.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace comm {
|
||||||
|
|
||||||
|
void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); }
|
||||||
|
|
||||||
|
void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
||||||
|
MS_EXCEPTION_IF_NULL(buffer);
|
||||||
|
|
||||||
|
if (message_callback_) {
|
||||||
|
message_callback_(buffer, num);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace comm
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,47 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace comm {
|
||||||
|
|
||||||
|
using messageReceive = std::function<void(const void *buffer, size_t len)>;
|
||||||
|
|
||||||
|
class TcpMessageHandler {
|
||||||
|
public:
|
||||||
|
TcpMessageHandler() = default;
|
||||||
|
virtual ~TcpMessageHandler() = default;
|
||||||
|
|
||||||
|
void SetCallback(messageReceive cb);
|
||||||
|
void ReceiveMessage(const void *buffer, size_t num);
|
||||||
|
|
||||||
|
private:
|
||||||
|
messageReceive message_callback_;
|
||||||
|
};
|
||||||
|
} // namespace comm
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,107 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
|
||||||
|
|
||||||
|
#include <event2/buffer.h>
|
||||||
|
#include <event2/bufferevent.h>
|
||||||
|
#include <event2/event.h>
|
||||||
|
#include <event2/listener.h>
|
||||||
|
#include <exception>
|
||||||
|
#include <functional>
|
||||||
|
#include <iostream>
|
||||||
|
#include <map>
|
||||||
|
#include <mutex>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "ps/comm/tcp_message_handler.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace comm {
|
||||||
|
|
||||||
|
class TcpServer;
|
||||||
|
class TcpConnection {
|
||||||
|
public:
|
||||||
|
TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {}
|
||||||
|
virtual ~TcpConnection() = default;
|
||||||
|
|
||||||
|
virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server);
|
||||||
|
void SendMessage(const void *buffer, size_t num) const;
|
||||||
|
virtual void OnReadHandler(const void *buffer, size_t numBytes);
|
||||||
|
TcpServer *GetServer() const;
|
||||||
|
evutil_socket_t GetFd() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
TcpMessageHandler tcp_message_handler_;
|
||||||
|
struct bufferevent *buffer_event_;
|
||||||
|
evutil_socket_t fd_;
|
||||||
|
TcpServer *server_;
|
||||||
|
};
|
||||||
|
|
||||||
|
using OnServerReceiveMessage =
|
||||||
|
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const void *buffer, size_t num)>;
|
||||||
|
|
||||||
|
class TcpServer {
|
||||||
|
public:
|
||||||
|
using OnConnected = std::function<void(const TcpServer *, const TcpConnection *)>;
|
||||||
|
using OnDisconnected = std::function<void(const TcpServer *, const TcpConnection *)>;
|
||||||
|
using OnAccepted = std::function<const TcpConnection *(const TcpServer *)>;
|
||||||
|
|
||||||
|
explicit TcpServer(std::string address, std::uint16_t port);
|
||||||
|
virtual ~TcpServer();
|
||||||
|
|
||||||
|
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
|
||||||
|
const OnAccepted &client_accept);
|
||||||
|
void InitServer();
|
||||||
|
void Start();
|
||||||
|
void Stop();
|
||||||
|
void SendToAllClients(const char *data, size_t len);
|
||||||
|
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
|
||||||
|
void RemoveConnection(const evutil_socket_t &fd);
|
||||||
|
void ReceiveMessage(const OnServerReceiveMessage &cb);
|
||||||
|
static void SendMessage(const TcpConnection &conn, const void *data, size_t num);
|
||||||
|
void SendMessage(const void *data, size_t num);
|
||||||
|
OnServerReceiveMessage GetServerReceiveMessage() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
|
||||||
|
int socklen, void *server);
|
||||||
|
static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server);
|
||||||
|
static void ReadCallback(struct bufferevent *, void *connection);
|
||||||
|
static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
|
||||||
|
virtual TcpConnection *onCreateConnection();
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct event_base *base_;
|
||||||
|
struct event *signal_event_;
|
||||||
|
struct evconnlistener *listener_;
|
||||||
|
std::string server_address_;
|
||||||
|
std::uint16_t server_port_;
|
||||||
|
|
||||||
|
std::map<evutil_socket_t, const TcpConnection *> connections_;
|
||||||
|
OnConnected client_connection_;
|
||||||
|
OnDisconnected client_disconnection_;
|
||||||
|
OnAccepted client_accept_;
|
||||||
|
std::recursive_mutex connection_mutex_;
|
||||||
|
OnServerReceiveMessage message_callback_;
|
||||||
|
};
|
||||||
|
} // namespace comm
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
|
@ -0,0 +1,46 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "common/common_test.h"
|
||||||
|
#include "ps/comm/tcp_client.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace comm {
|
||||||
|
class TestTcpClient : public UT::Common {
|
||||||
|
public:
|
||||||
|
TestTcpClient() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTcpClient, InitClientIPError) {
|
||||||
|
auto client = new TcpClient("127.0.0.13543", 9000);
|
||||||
|
client->ReceiveMessage(
|
||||||
|
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
|
||||||
|
|
||||||
|
ASSERT_THROW(client->InitTcpClient(), std::exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTcpClient, InitClientPortErrorNoException) {
|
||||||
|
auto client = new TcpClient("127.0.0.1", -1);
|
||||||
|
client->ReceiveMessage(
|
||||||
|
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
|
||||||
|
|
||||||
|
EXPECT_NO_THROW(client->InitTcpClient());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace comm
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,71 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/comm/tcp_client.h"
|
||||||
|
#include "ps/comm/tcp_server.h"
|
||||||
|
#include "common/common_test.h"
|
||||||
|
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace comm {
|
||||||
|
class TestTcpServer : public UT::Common {
|
||||||
|
public:
|
||||||
|
TestTcpServer() = default;
|
||||||
|
void SetUp() override {
|
||||||
|
server_ = new TcpServer("127.0.0.1", 9000);
|
||||||
|
std::unique_ptr<std::thread> http_server_thread_(nullptr);
|
||||||
|
http_server_thread_ = std::make_unique<std::thread>([&]() {
|
||||||
|
server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) {
|
||||||
|
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
|
||||||
|
server.SendMessage(conn, buffer, num);
|
||||||
|
});
|
||||||
|
server_->InitServer();
|
||||||
|
server_->Start();
|
||||||
|
});
|
||||||
|
http_server_thread_->detach();
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||||
|
}
|
||||||
|
void TearDown() override {
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||||
|
client_->Stop();
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||||
|
server_->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
TcpClient *client_;
|
||||||
|
TcpServer *server_;
|
||||||
|
const std::string test_message_ = "TCP_MESSAGE";
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTcpServer, ServerSendMessage) {
|
||||||
|
client_ = new TcpClient("127.0.0.1", 9000);
|
||||||
|
std::unique_ptr<std::thread> http_client_thread(nullptr);
|
||||||
|
http_client_thread = std::make_unique<std::thread>([&]() {
|
||||||
|
client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) {
|
||||||
|
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
|
||||||
|
});
|
||||||
|
|
||||||
|
client_->InitTcpClient();
|
||||||
|
client_->SendMessage(test_message_.c_str(), test_message_.size());
|
||||||
|
client_->Start();
|
||||||
|
});
|
||||||
|
http_client_thread->detach();
|
||||||
|
}
|
||||||
|
} // namespace comm
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
Loading…
Reference in new issue