added message handler unit test

pull/8335/head
anancds 5 years ago
parent 8aa78c2c8e
commit 96d8c411e7

@ -100,8 +100,8 @@ message("onnx proto path is :" ${ONNX_PROTO})
ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO})
list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS})
include_directories("${CMAKE_BINARY_DIR}/ps/comm")
file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/comm/protos/*.proto")
include_directories("${CMAKE_BINARY_DIR}/ps/core")
file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/core/protos/*.proto")
ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN})
list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS})

@ -5,12 +5,13 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc")
list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "util.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/http_message_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/http_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/comm_util.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_client.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_message_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/http_message_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/http_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/comm_util.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_client.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_message_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc")
endif()
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)

@ -0,0 +1,58 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/cluster_config.h"
#include <string>
namespace mindspore {
namespace ps {
namespace core {
uint32_t ClusterConfig::worker_num_ = 0;
uint32_t ClusterConfig::server_num_ = 0;
uint32_t ClusterConfig::heartbeat_interval_ = kHeartbeatInterval;
std::unique_ptr<std::string> ClusterConfig::scheduler_host_ = nullptr;
uint16_t ClusterConfig::scheduler_port_ = 0;
void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num,
std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) {
worker_num_ = worker_num;
server_num_ = server_num;
if (!CommUtil::CheckIp(*scheduler_host.get())) {
MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!";
}
scheduler_host_ = std::move(scheduler_host);
scheduler_port_ = scheduler_port;
}
uint32_t ClusterConfig::worker_num() { return worker_num_; }
uint32_t ClusterConfig::server_num() { return server_num_; }
uint32_t ClusterConfig::heartbeat_interval() { return heartbeat_interval_; }
void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) {
heartbeat_interval_ = heartbeat_interval;
}
std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); }
uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; }
} // namespace core
} // namespace ps
} // namespace mindspore

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_
#define MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_
#include <string>
#include <iostream>
#include <memory>
#include <utility>
#include "utils/log_adapter.h"
#include "ps/core/comm_util.h"
namespace mindspore {
namespace ps {
namespace core {
constexpr uint32_t kHeartbeatInterval = 3;
class ClusterConfig {
public:
static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host,
const uint16_t &scheduler_port);
static uint32_t worker_num();
static uint32_t server_num();
static uint32_t heartbeat_interval();
static void set_heartbeat_interval(const uint32_t &heartbeat_interval);
static std::string scheduler_host();
static uint16_t scheduler_port();
private:
static uint32_t worker_num_;
static uint32_t server_num_;
static uint32_t heartbeat_interval_;
static std::unique_ptr<std::string> scheduler_host_;
static uint16_t scheduler_port_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "ps/comm/comm_util.h"
#include "ps/core/comm_util.h"
#include <arpa/inet.h>
#include <cstdio>
@ -25,7 +25,7 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
bool CommUtil::CheckIpWithRegex(const std::string &ip) {
std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)");
@ -36,15 +36,47 @@ bool CommUtil::CheckIpWithRegex(const std::string &ip) {
return false;
}
void CommUtil::CheckIp(const std::string &ip) {
bool CommUtil::CheckIp(const std::string &ip) {
if (!CheckIpWithRegex(ip)) {
MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!";
return false;
}
int64_t uAddr = inet_addr(ip.c_str());
if (INADDR_NONE == uAddr) {
MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!";
return false;
}
return true;
}
} // namespace comm
void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *ip) {
MS_EXCEPTION_IF_NULL(interface);
MS_EXCEPTION_IF_NULL(ip);
struct ifaddrs *if_address = nullptr;
struct ifaddrs *ifa = nullptr;
interface->clear();
ip->clear();
getifaddrs(&if_address);
for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == nullptr) {
continue;
}
if (ifa->ifa_addr->sa_family == AF_INET && (ifa->ifa_flags & IFF_LOOPBACK) == 0) {
char address_buffer[INET_ADDRSTRLEN] = {0};
void *sin_addr_ptr = &(reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr))->sin_addr;
MS_EXCEPTION_IF_NULL(sin_addr_ptr);
const char *net_ptr = inet_ntop(AF_INET, sin_addr_ptr, address_buffer, INET_ADDRSTRLEN);
MS_EXCEPTION_IF_NULL(net_ptr);
*ip = address_buffer;
*interface = ifa->ifa_name;
break;
}
}
MS_EXCEPTION_IF_NULL(if_address);
freeifaddrs(if_address);
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -14,8 +14,21 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
#define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
#ifndef MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_
#define MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_
#include <unistd.h>
#ifdef _MSC_VER
#include <tchar.h>
#include <winsock2.h>
#include <windows.h>
#include <iphlpapi.h>
#else
#include <net/if.h>
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <netinet/in.h>
#endif
#include <event2/buffer.h>
#include <event2/event.h>
@ -35,15 +48,16 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
class CommUtil {
public:
static bool CheckIpWithRegex(const std::string &ip);
static void CheckIp(const std::string &ip);
static bool CheckIp(const std::string &ip);
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
};
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
#endif // MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "ps/comm/http_message_handler.h"
#include "ps/core/http_message_handler.h"
#include <event2/event.h>
#include <event2/buffer.h>
@ -36,7 +36,7 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
void HttpMessageHandler::InitHttpMessage() {
MS_EXCEPTION_IF_NULL(event_request_);
@ -202,6 +202,6 @@ void HttpMessageHandler::RespError(int nCode, const std::string &message) {
evhttp_send_error(event_request_, nCode, message.c_str());
}
}
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_
#define MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_
#ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_
#define MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_
#include <event2/buffer.h>
#include <event2/event.h>
@ -36,7 +36,7 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
using HttpHeaders = std::map<std::string, std::list<std::string>>;
@ -101,7 +101,7 @@ class HttpMessageHandler {
void ParsePostParam();
};
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_
#endif // MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_

@ -14,9 +14,9 @@
* limitations under the License.
*/
#include "ps/comm/http_server.h"
#include "ps/comm/http_message_handler.h"
#include "ps/comm/comm_util.h"
#include "ps/core/http_server.h"
#include "ps/core/http_message_handler.h"
#include "ps/core/comm_util.h"
#ifdef WIN32
#include <WinSock2.h>
@ -40,12 +40,14 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
HttpServer::~HttpServer() { Stop(); }
bool HttpServer::InitServer() {
CommUtil::CheckIp(server_address_);
if (!CommUtil::CheckIp(server_address_)) {
MS_LOG(EXCEPTION) << "The http server ip:" << server_address_ << " is illegal!";
}
event_base_ = event_base_new();
MS_EXCEPTION_IF_NULL(event_base_);
@ -154,6 +156,6 @@ void HttpServer::Stop() {
}
}
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore

@ -14,10 +14,10 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_
#define MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_
#ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_
#define MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_
#include "ps/comm/http_message_handler.h"
#include "ps/core/http_message_handler.h"
#include <event2/buffer.h>
#include <event2/event.h>
@ -35,7 +35,7 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
typedef enum eHttpMethod {
HM_GET = 1 << 0,
@ -86,8 +86,8 @@ class HttpServer {
bool is_init_;
};
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_
#endif // MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_

@ -16,9 +16,24 @@
syntax = "proto3";
import "google/protobuf/any.proto";
package mindspore.ps;
package mindspore.ps.core;
option optimize_for = LITE_RUNTIME;
enum ClusterCommand {
TERMINATE = 0;
REGISTER = 1;
ACK = 2;
HEARTBEAT = 3;
FETCH_WORKERS = 4;
FETCH_SERVERS = 5;
}
enum Role {
SERVER = 0;
WORKER = 1;
SCHEDULER = 2;
}
message MessageMeta {
// hostname or ip
string hostname = 1;

@ -14,6 +14,10 @@
* limitations under the License.
*/
syntax = "proto3";
package mindspore.ps.core;
option optimize_for = LITE_RUNTIME;
message KVMessage {
repeated int32 keys = 1;
repeated float values = 2;

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "ps/comm/tcp_client.h"
#include "ps/core/tcp_client.h"
#include <arpa/inet.h>
#include <event2/buffer.h>
@ -30,11 +30,11 @@
#include <utility>
#include <string>
#include "ps/comm/comm_util.h"
#include "ps/core/comm_util.h"
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
TcpClient::TcpClient(const std::string &address, std::uint16_t port)
: event_base_(nullptr),
@ -65,7 +65,9 @@ void TcpClient::Init() {
if (buffer_event_) {
return;
}
CommUtil::CheckIp(server_address_);
if (!CommUtil::CheckIp(server_address_)) {
MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
}
event_base_ = event_base_new();
MS_EXCEPTION_IF_NULL(event_base_);
@ -166,6 +168,23 @@ void TcpClient::OnReadHandler(const void *buf, size_t num) {
message_handler_.ReceiveMessage(buf, num);
}
void TcpClient::SendHeartBeatCallback(evutil_socket_t, int16_t, void *arg) {
MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
MessageMeta meta;
meta.set_cmd(ClusterCommand::HEARTBEAT);
CommMessage message;
message.set_allocated_pb_meta(&meta);
tcp_client->SendMessage(message);
struct event *ev;
struct timeval timeout {};
timeout.tv_sec = ClusterConfig::heartbeat_interval();
timeout.tv_usec = 0;
ev = evtimer_new(tcp_client->event_base_, SendHeartBeatCallback, arg);
evtimer_add(ev, &timeout);
}
void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(ptr);
@ -226,6 +245,16 @@ void TcpClient::SendMessage(const CommMessage &message) const {
}
}
} // namespace comm
void TcpClient::SendMessageWithTimer() {
MS_EXCEPTION_IF_NULL(buffer_event_);
struct event *ev = nullptr;
struct timeval timeout {};
timeout.tv_sec = 0;
timeout.tv_usec = 0;
ev = evtimer_new(event_base_, SendHeartBeatCallback, this);
evtimer_add(ev, &timeout);
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -14,10 +14,10 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
#define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
#ifndef MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_
#define MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_
#include "ps/comm/tcp_message_handler.h"
#include "ps/core/tcp_message_handler.h"
#include <event2/event.h>
#include <event2/bufferevent.h>
@ -27,10 +27,11 @@
#include <vector>
#include "proto/comm.pb.h"
#include "ps/core/cluster_config.h"
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
class TcpClient {
public:
@ -53,6 +54,7 @@ class TcpClient {
void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb);
void SendMessage(const CommMessage &message) const;
void SendMessageWithTimer();
protected:
static void SetTcpNoDelay(const evutil_socket_t &fd);
@ -60,6 +62,7 @@ class TcpClient {
static void ReadCallback(struct bufferevent *bev, void *ctx);
static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
virtual void OnReadHandler(const void *buf, size_t num);
static void SendHeartBeatCallback(evutil_socket_t fd, int16_t event, void *arg);
private:
OnMessage message_callback_;
@ -78,7 +81,7 @@ class TcpClient {
std::uint16_t server_port_;
};
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
#endif // MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "ps/comm/tcp_message_handler.h"
#include "ps/core/tcp_message_handler.h"
#include <arpa/inet.h>
#include <iostream>
@ -22,7 +22,7 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; }
@ -37,16 +37,15 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
--num;
if (header_index_ == 3) {
message_length_ = *reinterpret_cast<const uint32_t *>(header_);
message_length_ = ntohl(message_length_);
remaining_length_ = message_length_;
message_buffer_.reset(new unsigned char[remaining_length_]);
buffer_data += i;
buffer_data += (i + 1);
break;
}
}
}
if (remaining_length_ > 0) {
if (remaining_length_ > 0 && num > 0) {
uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
remaining_length_ -= copy_len;
num -= copy_len;
@ -60,19 +59,19 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
if (remaining_length_ == 0) {
CommMessage pb_message;
pb_message.ParseFromArray(reinterpret_cast<const void *>(message_buffer_.get()), message_length_);
pb_message.ParseFromArray(message_buffer_.get(), message_length_);
if (message_callback_) {
message_callback_(pb_message);
}
message_buffer_.reset();
message_buffer_ = nullptr;
header_index_ = 0;
header_index_ = -1;
last_copy_len_ = 0;
}
}
}
}
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
#define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
#ifndef MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_
#define MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_
#include <functional>
#include <iostream>
@ -29,7 +29,7 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
using messageReceive = std::function<void(const CommMessage &message)>;
@ -57,8 +57,8 @@ class TcpMessageHandler {
int header_index_;
uint32_t last_copy_len_;
};
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
#endif // MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "ps/comm/tcp_server.h"
#include "ps/core/tcp_server.h"
#include <arpa/inet.h>
#include <event2/buffer.h>
@ -27,11 +27,11 @@
#include <csignal>
#include <utility>
#include "ps/comm/comm_util.h"
#include "ps/core/comm_util.h"
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
void TcpConnection::InitConnection() {
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
@ -88,7 +88,9 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon
void TcpServer::Init() {
base_ = event_base_new();
MS_EXCEPTION_IF_NULL(base_);
CommUtil::CheckIp(server_address_);
if (!CommUtil::CheckIp(server_address_)) {
MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!";
}
struct sockaddr_in sin {};
if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
@ -104,6 +106,18 @@ void TcpServer::Init() {
MS_EXCEPTION_IF_NULL(listener_);
if (server_port_ == 0) {
struct sockaddr_in sin_bound {};
if (memset_s(&sin, sizeof(sin_bound), 0, sizeof(sin_bound)) != EOK) {
MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
}
socklen_t addr_len = sizeof(struct sockaddr_in);
if (getsockname(evconnlistener_get_fd(listener_), (struct sockaddr *)&sin_bound, &addr_len) != 0) {
MS_LOG(EXCEPTION) << "Get sock name failed!";
}
server_port_ = htons(sin_bound.sin_port);
}
signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this));
MS_EXCEPTION_IF_NULL(signal_event_);
if (event_add(signal_event_, nullptr) < 0) {
@ -173,11 +187,13 @@ void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
connections_.erase(fd);
}
void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *, int, void *data) {
void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int,
void *data) {
auto server = reinterpret_cast<class TcpServer *>(data);
auto base = reinterpret_cast<struct event_base *>(server->base_);
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(base);
MS_EXCEPTION_IF_NULL(sockaddr);
struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE);
if (!bev) {
@ -279,8 +295,10 @@ void TcpServer::SendMessage(const CommMessage &message) {
}
}
uint16_t TcpServer::BoundPort() const { return server_port_; }
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
#define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
#ifndef MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_
#define MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_
#include <event2/buffer.h>
#include <event2/bufferevent.h>
@ -31,11 +31,11 @@
#include <vector>
#include "utils/log_adapter.h"
#include "ps/comm/tcp_message_handler.h"
#include "ps/core/tcp_message_handler.h"
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
class TcpServer;
class TcpConnection {
@ -83,6 +83,7 @@ class TcpServer {
void SetMessageCallback(const OnServerReceiveMessage &cb);
static void SendMessage(const TcpConnection &conn, const CommMessage &message);
void SendMessage(const CommMessage &message);
uint16_t BoundPort() const;
protected:
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
@ -106,7 +107,7 @@ class TcpServer {
OnServerReceiveMessage message_callback_;
};
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
#endif // MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "common/common_test.h"
#include "ps/core/cluster_config.h"
namespace mindspore {
namespace ps {
namespace core {
class TestClusterConfig : public UT::Common {
public:
TestClusterConfig() = default;
virtual ~TestClusterConfig() = default;
void SetUp() override {}
void TearDown() override {}
};
TEST_F(TestClusterConfig, HeartbeatInterval) {
ClusterConfig::Init(2, 2, std::make_unique<std::string>("127.0.0.1"), 8080);
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3);
ClusterConfig::set_heartbeat_interval(100);
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100);
EXPECT_STREQ(ClusterConfig::scheduler_host().c_str(), "127.0.0.1");
EXPECT_TRUE(ClusterConfig::scheduler_port() == 8080);
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "ps/core/comm_util.h"
#include <memory>
#include <thread>
namespace mindspore {
namespace ps {
namespace core {
class TestCommUtil : public UT::Common {
public:
TestCommUtil() = default;
virtual ~TestCommUtil() = default;
void SetUp() override {}
void TearDown() override {}
};
TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
std::string interface;
std::string ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &ip);
EXPECT_TRUE(!interface.empty());
EXPECT_TRUE(!ip.empty());
}
} // namespace comm
} // namespace ps
} // namespace mindspore

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "ps/comm/http_server.h"
#include "ps/core/http_server.h"
#include "common/common_test.h"
#include <gtest/gtest.h>
#include <algorithm>
@ -28,7 +28,7 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
class TestHttpServer : public UT::Common {
public:

@ -17,11 +17,11 @@
#include <memory>
#include "common/common_test.h"
#include "ps/comm/tcp_client.h"
#include "ps/core/tcp_client.h"
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
class TestTcpClient : public UT::Common {
public:
TestTcpClient() = default;

@ -0,0 +1,163 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/tcp_message_handler.h"
#include "common/common_test.h"
#include <memory>
#include <thread>
namespace mindspore {
namespace ps {
namespace core {
class TestTcpMessageHandler : public UT::Common {
public:
using messageReceive = std::function<void(const CommMessage &message)>;
TestTcpMessageHandler() = default;
virtual ~TestTcpMessageHandler() = default;
void SetUp() override {}
void TearDown() override {}
};
TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
std::string data(1000, 'a');
CommMessage message;
message.set_data(data);
uint32_t buf_size = message.ByteSizeLong();
char result[1007];
int ret = memcpy_s(result, 4, &buf_size, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
handler.ReceiveMessage(result, buf_size + 4);
}
TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
std::string data(1000, 'a');
CommMessage message;
message.set_data(data);
uint32_t buf_size = message.ByteSizeLong();
char result[2014];
int ret = memcpy_s(result, 4, &buf_size, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 2 * buf_size + 4 * 2);
}
TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); });
std::string data(4087, 'a');
CommMessage message;
message.set_data(data);
uint32_t buf_size = message.ByteSizeLong();
char result[4096];
int ret = memcpy_s(result, 4, &buf_size, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4096);
ret = memcpy_s(result, 2, &buf_size + 2, 2);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4092);
}
TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); });
std::string data(4085, 'a');
CommMessage message;
message.set_data(data);
uint32_t buf_size = message.ByteSizeLong();
char result[4096];
int ret = memcpy_s(result, 4, &buf_size, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4096);
ret = memcpy_s(result, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4088);
}
} // namespace comm
} // namespace ps
} // namespace mindspore

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "ps/comm/tcp_client.h"
#include "ps/comm/tcp_server.h"
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"
#include "common/common_test.h"
#include <memory>
@ -23,14 +23,14 @@
namespace mindspore {
namespace ps {
namespace comm {
namespace core {
class TestTcpServer : public UT::Common {
public:
TestTcpServer() : client_(nullptr), server_(nullptr) {}
virtual ~TestTcpServer() = default;
void SetUp() override {
server_ = std::make_unique<TcpServer>("127.0.0.1", 9998);
server_ = std::make_unique<TcpServer>("127.0.0.1", 0);
std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([&]() {
server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
@ -57,7 +57,7 @@ class TestTcpServer : public UT::Common {
};
TEST_F(TestTcpServer, ServerSendMessage) {
client_ = std::make_unique<TcpClient>("127.0.0.1", 9998);
client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort());
std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() {
client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {
@ -82,6 +82,6 @@ TEST_F(TestTcpServer, ServerSendMessage) {
});
http_client_thread->detach();
}
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore
Loading…
Cancel
Save