!7518 Added tcp server based on libevent
	
		
	
				
					
				
			Merge pull request !7518 from anancds/tcp-serverpull/7518/MERGE
						commit
						80725441b6
					
				@ -0,0 +1,50 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#include "ps/comm/comm_util.h"
 | 
				
			||||
 | 
				
			||||
#include <arpa/inet.h>
 | 
				
			||||
#include <cstdio>
 | 
				
			||||
#include <cstdlib>
 | 
				
			||||
#include <cstring>
 | 
				
			||||
#include <functional>
 | 
				
			||||
#include <regex>
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace ps {
 | 
				
			||||
namespace comm {
 | 
				
			||||
 | 
				
			||||
bool CommUtil::CheckIpWithRegex(const std::string &ip) {
 | 
				
			||||
  std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)");
 | 
				
			||||
  std::smatch res;
 | 
				
			||||
  if (regex_match(ip, res, pattern)) {
 | 
				
			||||
    return true;
 | 
				
			||||
  }
 | 
				
			||||
  return false;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void CommUtil::CheckIp(const std::string &ip) {
 | 
				
			||||
  if (!CheckIpWithRegex(ip)) {
 | 
				
			||||
    MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!";
 | 
				
			||||
  }
 | 
				
			||||
  int64_t uAddr = inet_addr(ip.c_str());
 | 
				
			||||
  if (INADDR_NONE == uAddr) {
 | 
				
			||||
    MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!";
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
}  // namespace comm
 | 
				
			||||
}  // namespace ps
 | 
				
			||||
}  // namespace mindspore
 | 
				
			||||
@ -0,0 +1,49 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
 | 
				
			||||
#define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
 | 
				
			||||
 | 
				
			||||
#include <event2/buffer.h>
 | 
				
			||||
#include <event2/event.h>
 | 
				
			||||
#include <event2/http.h>
 | 
				
			||||
#include <event2/keyvalq_struct.h>
 | 
				
			||||
#include <event2/listener.h>
 | 
				
			||||
#include <event2/util.h>
 | 
				
			||||
 | 
				
			||||
#include <cstdio>
 | 
				
			||||
#include <cstdlib>
 | 
				
			||||
#include <cstring>
 | 
				
			||||
#include <functional>
 | 
				
			||||
#include <string>
 | 
				
			||||
#include <utility>
 | 
				
			||||
 | 
				
			||||
#include "utils/log_adapter.h"
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace ps {
 | 
				
			||||
namespace comm {
 | 
				
			||||
 | 
				
			||||
class CommUtil {
 | 
				
			||||
 public:
 | 
				
			||||
  static bool CheckIpWithRegex(const std::string &ip);
 | 
				
			||||
  static void CheckIp(const std::string &ip);
 | 
				
			||||
};
 | 
				
			||||
}  // namespace comm
 | 
				
			||||
}  // namespace ps
 | 
				
			||||
}  // namespace mindspore
 | 
				
			||||
 | 
				
			||||
#endif  // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
 | 
				
			||||
@ -0,0 +1,220 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#include "ps/comm/tcp_client.h"
 | 
				
			||||
 | 
				
			||||
#include <arpa/inet.h>
 | 
				
			||||
#include <event2/buffer.h>
 | 
				
			||||
#include <event2/bufferevent.h>
 | 
				
			||||
#include <event2/buffer_compat.h>
 | 
				
			||||
#include <event2/event.h>
 | 
				
			||||
#include <netinet/in.h>
 | 
				
			||||
#include <netinet/tcp.h>
 | 
				
			||||
#include <sys/socket.h>
 | 
				
			||||
#include <cstdlib>
 | 
				
			||||
#include <cstring>
 | 
				
			||||
#include <iostream>
 | 
				
			||||
#include <utility>
 | 
				
			||||
#include <string>
 | 
				
			||||
 | 
				
			||||
#include "ps/comm/comm_util.h"
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace ps {
 | 
				
			||||
namespace comm {
 | 
				
			||||
 | 
				
			||||
TcpClient::TcpClient(std::string address, std::uint16_t port)
 | 
				
			||||
    : event_base_(nullptr),
 | 
				
			||||
      event_timeout_(nullptr),
 | 
				
			||||
      buffer_event_(nullptr),
 | 
				
			||||
      server_address_(std::move(address)),
 | 
				
			||||
      server_port_(port) {
 | 
				
			||||
  message_handler_.SetCallback([this](const void *buf, size_t num) {
 | 
				
			||||
    if (buf == nullptr) {
 | 
				
			||||
      if (disconnected_callback_) disconnected_callback_(*this, 200);
 | 
				
			||||
      Stop();
 | 
				
			||||
    }
 | 
				
			||||
    if (message_callback_) message_callback_(*this, buf, num);
 | 
				
			||||
  });
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
TcpClient::~TcpClient() { Stop(); }
 | 
				
			||||
 | 
				
			||||
std::string TcpClient::GetServerAddress() const { return server_address_; }
 | 
				
			||||
 | 
				
			||||
void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
 | 
				
			||||
                            const OnTimeout &timeout) {
 | 
				
			||||
  connected_callback_ = conn;
 | 
				
			||||
  disconnected_callback_ = disconn;
 | 
				
			||||
  read_callback_ = read;
 | 
				
			||||
  timeout_callback_ = timeout;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::InitTcpClient() {
 | 
				
			||||
  if (buffer_event_) {
 | 
				
			||||
    return;
 | 
				
			||||
  }
 | 
				
			||||
  CommUtil::CheckIp(server_address_);
 | 
				
			||||
 | 
				
			||||
  event_base_ = event_base_new();
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(event_base_);
 | 
				
			||||
 | 
				
			||||
  sockaddr_in sin{};
 | 
				
			||||
  if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
 | 
				
			||||
    MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
 | 
				
			||||
  }
 | 
				
			||||
  sin.sin_family = AF_INET;
 | 
				
			||||
  sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
 | 
				
			||||
  sin.sin_port = htons(server_port_);
 | 
				
			||||
 | 
				
			||||
  buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE);
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(buffer_event_);
 | 
				
			||||
 | 
				
			||||
  bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
 | 
				
			||||
  if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) {
 | 
				
			||||
    MS_LOG(EXCEPTION) << "buffer event enable read and write failed!";
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
 | 
				
			||||
  if (result_code < 0) {
 | 
				
			||||
    MS_LOG(EXCEPTION) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!";
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::StartWithDelay(int seconds) {
 | 
				
			||||
  if (buffer_event_) {
 | 
				
			||||
    return;
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  event_base_ = event_base_new();
 | 
				
			||||
 | 
				
			||||
  timeval timeout_value{};
 | 
				
			||||
  timeout_value.tv_sec = seconds;
 | 
				
			||||
  timeout_value.tv_usec = 0;
 | 
				
			||||
 | 
				
			||||
  event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
 | 
				
			||||
  if (evtimer_add(event_timeout_, &timeout_value) == -1) {
 | 
				
			||||
    MS_LOG(EXCEPTION) << "event timeout failed!";
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::Stop() {
 | 
				
			||||
  if (buffer_event_) {
 | 
				
			||||
    bufferevent_free(buffer_event_);
 | 
				
			||||
    buffer_event_ = nullptr;
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  if (event_timeout_) {
 | 
				
			||||
    event_free(event_timeout_);
 | 
				
			||||
    event_timeout_ = nullptr;
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  if (event_base_) {
 | 
				
			||||
    event_base_free(event_base_);
 | 
				
			||||
    event_base_ = nullptr;
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
 | 
				
			||||
  const int one = 1;
 | 
				
			||||
  int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
 | 
				
			||||
  if (ret < 0) {
 | 
				
			||||
    MS_LOG(EXCEPTION) << "Set socket no delay failed!";
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(arg);
 | 
				
			||||
  auto tcp_client = reinterpret_cast<TcpClient *>(arg);
 | 
				
			||||
  tcp_client->InitTcpClient();
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(bev);
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(ctx);
 | 
				
			||||
  auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
 | 
				
			||||
  struct evbuffer *input = bufferevent_get_input(const_cast<struct bufferevent *>(bev));
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(input);
 | 
				
			||||
 | 
				
			||||
  char read_buffer[4096];
 | 
				
			||||
  int read = 0;
 | 
				
			||||
 | 
				
			||||
  while ((read = EVBUFFER_LENGTH(input)) > 0) {
 | 
				
			||||
    if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) {
 | 
				
			||||
      MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
 | 
				
			||||
    }
 | 
				
			||||
    tcp_client->OnReadHandler(read_buffer, read);
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::OnReadHandler(const void *buf, size_t num) {
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(buf);
 | 
				
			||||
  if (read_callback_) {
 | 
				
			||||
    read_callback_(*this, buf, num);
 | 
				
			||||
  }
 | 
				
			||||
  message_handler_.ReceiveMessage(buf, num);
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(bev);
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(ptr);
 | 
				
			||||
  auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
 | 
				
			||||
  if (events & BEV_EVENT_CONNECTED) {
 | 
				
			||||
    // Connected
 | 
				
			||||
    if (tcp_client->connected_callback_) {
 | 
				
			||||
      tcp_client->connected_callback_(*tcp_client);
 | 
				
			||||
    }
 | 
				
			||||
    evutil_socket_t fd = bufferevent_getfd(const_cast<struct bufferevent *>(bev));
 | 
				
			||||
    SetTcpNoDelay(fd);
 | 
				
			||||
    MS_LOG(INFO) << "Client connected!";
 | 
				
			||||
  } else if (events & BEV_EVENT_ERROR) {
 | 
				
			||||
    MS_LOG(ERROR) << "Client connected error!";
 | 
				
			||||
    if (tcp_client->disconnected_callback_) {
 | 
				
			||||
      tcp_client->disconnected_callback_(*tcp_client, errno);
 | 
				
			||||
    }
 | 
				
			||||
  } else if (events & BEV_EVENT_EOF) {
 | 
				
			||||
    MS_LOG(ERROR) << "Client connected end of file";
 | 
				
			||||
    if (tcp_client->disconnected_callback_) {
 | 
				
			||||
      tcp_client->disconnected_callback_(*tcp_client, 0);
 | 
				
			||||
    }
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::Start() {
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(event_base_);
 | 
				
			||||
  int ret = event_base_dispatch(event_base_);
 | 
				
			||||
  if (ret == 0) {
 | 
				
			||||
    MS_LOG(INFO) << "Event base dispatch success!";
 | 
				
			||||
  } else if (ret == 1) {
 | 
				
			||||
    MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
 | 
				
			||||
  } else if (ret == -1) {
 | 
				
			||||
    MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
 | 
				
			||||
  } else {
 | 
				
			||||
    MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; }
 | 
				
			||||
 | 
				
			||||
void TcpClient::SendMessage(const void *buf, size_t num) const {
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(buffer_event_);
 | 
				
			||||
  if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) {
 | 
				
			||||
    MS_LOG(EXCEPTION) << "event buffer add failed!";
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
}  // namespace comm
 | 
				
			||||
}  // namespace ps
 | 
				
			||||
}  // namespace mindspore
 | 
				
			||||
@ -0,0 +1,77 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
 | 
				
			||||
#define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
 | 
				
			||||
 | 
				
			||||
#include "ps/comm/tcp_message_handler.h"
 | 
				
			||||
 | 
				
			||||
#include <event2/event.h>
 | 
				
			||||
#include <event2/bufferevent.h>
 | 
				
			||||
#include <functional>
 | 
				
			||||
#include <string>
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace ps {
 | 
				
			||||
namespace comm {
 | 
				
			||||
 | 
				
			||||
class TcpClient {
 | 
				
			||||
 public:
 | 
				
			||||
  using OnMessage = std::function<void(const TcpClient &, const void *, size_t)>;
 | 
				
			||||
  using OnConnected = std::function<void(const TcpClient &)>;
 | 
				
			||||
  using OnDisconnected = std::function<void(const TcpClient &, int)>;
 | 
				
			||||
  using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
 | 
				
			||||
  using OnTimeout = std::function<void(const TcpClient &)>;
 | 
				
			||||
 | 
				
			||||
  explicit TcpClient(std::string address, std::uint16_t port);
 | 
				
			||||
  virtual ~TcpClient();
 | 
				
			||||
 | 
				
			||||
  std::string GetServerAddress() const;
 | 
				
			||||
  void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
 | 
				
			||||
                   const OnTimeout &timeout);
 | 
				
			||||
  void InitTcpClient();
 | 
				
			||||
  void StartWithDelay(int seconds);
 | 
				
			||||
  void Stop();
 | 
				
			||||
  void ReceiveMessage(const OnMessage &cb);
 | 
				
			||||
  void SendMessage(const void *buf, size_t num) const;
 | 
				
			||||
  void Start();
 | 
				
			||||
 | 
				
			||||
 protected:
 | 
				
			||||
  static void SetTcpNoDelay(const evutil_socket_t &fd);
 | 
				
			||||
  static void TimeoutCallback(evutil_socket_t fd, std::int16_t what, void *arg);
 | 
				
			||||
  static void ReadCallback(struct bufferevent *bev, void *ctx);
 | 
				
			||||
  static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
 | 
				
			||||
  virtual void OnReadHandler(const void *buf, size_t num);
 | 
				
			||||
 | 
				
			||||
 private:
 | 
				
			||||
  TcpMessageHandler message_handler_;
 | 
				
			||||
  OnMessage message_callback_;
 | 
				
			||||
  OnConnected connected_callback_;
 | 
				
			||||
  OnDisconnected disconnected_callback_;
 | 
				
			||||
  OnRead read_callback_;
 | 
				
			||||
  OnTimeout timeout_callback_;
 | 
				
			||||
 | 
				
			||||
  event_base *event_base_;
 | 
				
			||||
  event *event_timeout_;
 | 
				
			||||
  bufferevent *buffer_event_;
 | 
				
			||||
 | 
				
			||||
  std::string server_address_;
 | 
				
			||||
  std::uint16_t server_port_;
 | 
				
			||||
};
 | 
				
			||||
}  // namespace comm
 | 
				
			||||
}  // namespace ps
 | 
				
			||||
}  // namespace mindspore
 | 
				
			||||
#endif  // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
 | 
				
			||||
@ -0,0 +1,36 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#include "ps/comm/tcp_message_handler.h"
 | 
				
			||||
#include <iostream>
 | 
				
			||||
#include <utility>
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace ps {
 | 
				
			||||
namespace comm {
 | 
				
			||||
 | 
				
			||||
void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); }
 | 
				
			||||
 | 
				
			||||
void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
 | 
				
			||||
  MS_EXCEPTION_IF_NULL(buffer);
 | 
				
			||||
 | 
				
			||||
  if (message_callback_) {
 | 
				
			||||
    message_callback_(buffer, num);
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
}  // namespace comm
 | 
				
			||||
}  // namespace ps
 | 
				
			||||
}  // namespace mindspore
 | 
				
			||||
@ -0,0 +1,47 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
 | 
				
			||||
#define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
 | 
				
			||||
 | 
				
			||||
#include <functional>
 | 
				
			||||
#include <iostream>
 | 
				
			||||
#include <memory>
 | 
				
			||||
 | 
				
			||||
#include "utils/log_adapter.h"
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace ps {
 | 
				
			||||
namespace comm {
 | 
				
			||||
 | 
				
			||||
using messageReceive = std::function<void(const void *buffer, size_t len)>;
 | 
				
			||||
 | 
				
			||||
class TcpMessageHandler {
 | 
				
			||||
 public:
 | 
				
			||||
  TcpMessageHandler() = default;
 | 
				
			||||
  virtual ~TcpMessageHandler() = default;
 | 
				
			||||
 | 
				
			||||
  void SetCallback(messageReceive cb);
 | 
				
			||||
  void ReceiveMessage(const void *buffer, size_t num);
 | 
				
			||||
 | 
				
			||||
 private:
 | 
				
			||||
  messageReceive message_callback_;
 | 
				
			||||
};
 | 
				
			||||
}  // namespace comm
 | 
				
			||||
}  // namespace ps
 | 
				
			||||
}  // namespace mindspore
 | 
				
			||||
 | 
				
			||||
#endif  // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
 | 
				
			||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,107 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
 | 
				
			||||
#define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
 | 
				
			||||
 | 
				
			||||
#include <event2/buffer.h>
 | 
				
			||||
#include <event2/bufferevent.h>
 | 
				
			||||
#include <event2/event.h>
 | 
				
			||||
#include <event2/listener.h>
 | 
				
			||||
#include <exception>
 | 
				
			||||
#include <functional>
 | 
				
			||||
#include <iostream>
 | 
				
			||||
#include <map>
 | 
				
			||||
#include <mutex>
 | 
				
			||||
#include <string>
 | 
				
			||||
 | 
				
			||||
#include "utils/log_adapter.h"
 | 
				
			||||
#include "ps/comm/tcp_message_handler.h"
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace ps {
 | 
				
			||||
namespace comm {
 | 
				
			||||
 | 
				
			||||
class TcpServer;
 | 
				
			||||
class TcpConnection {
 | 
				
			||||
 public:
 | 
				
			||||
  TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {}
 | 
				
			||||
  virtual ~TcpConnection() = default;
 | 
				
			||||
 | 
				
			||||
  virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server);
 | 
				
			||||
  void SendMessage(const void *buffer, size_t num) const;
 | 
				
			||||
  virtual void OnReadHandler(const void *buffer, size_t numBytes);
 | 
				
			||||
  TcpServer *GetServer() const;
 | 
				
			||||
  evutil_socket_t GetFd() const;
 | 
				
			||||
 | 
				
			||||
 protected:
 | 
				
			||||
  TcpMessageHandler tcp_message_handler_;
 | 
				
			||||
  struct bufferevent *buffer_event_;
 | 
				
			||||
  evutil_socket_t fd_;
 | 
				
			||||
  TcpServer *server_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
using OnServerReceiveMessage =
 | 
				
			||||
  std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const void *buffer, size_t num)>;
 | 
				
			||||
 | 
				
			||||
class TcpServer {
 | 
				
			||||
 public:
 | 
				
			||||
  using OnConnected = std::function<void(const TcpServer *, const TcpConnection *)>;
 | 
				
			||||
  using OnDisconnected = std::function<void(const TcpServer *, const TcpConnection *)>;
 | 
				
			||||
  using OnAccepted = std::function<const TcpConnection *(const TcpServer *)>;
 | 
				
			||||
 | 
				
			||||
  explicit TcpServer(std::string address, std::uint16_t port);
 | 
				
			||||
  virtual ~TcpServer();
 | 
				
			||||
 | 
				
			||||
  void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
 | 
				
			||||
                         const OnAccepted &client_accept);
 | 
				
			||||
  void InitServer();
 | 
				
			||||
  void Start();
 | 
				
			||||
  void Stop();
 | 
				
			||||
  void SendToAllClients(const char *data, size_t len);
 | 
				
			||||
  void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
 | 
				
			||||
  void RemoveConnection(const evutil_socket_t &fd);
 | 
				
			||||
  void ReceiveMessage(const OnServerReceiveMessage &cb);
 | 
				
			||||
  static void SendMessage(const TcpConnection &conn, const void *data, size_t num);
 | 
				
			||||
  void SendMessage(const void *data, size_t num);
 | 
				
			||||
  OnServerReceiveMessage GetServerReceiveMessage() const;
 | 
				
			||||
 | 
				
			||||
 protected:
 | 
				
			||||
  static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
 | 
				
			||||
                               int socklen, void *server);
 | 
				
			||||
  static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server);
 | 
				
			||||
  static void ReadCallback(struct bufferevent *, void *connection);
 | 
				
			||||
  static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
 | 
				
			||||
  virtual TcpConnection *onCreateConnection();
 | 
				
			||||
 | 
				
			||||
 private:
 | 
				
			||||
  struct event_base *base_;
 | 
				
			||||
  struct event *signal_event_;
 | 
				
			||||
  struct evconnlistener *listener_;
 | 
				
			||||
  std::string server_address_;
 | 
				
			||||
  std::uint16_t server_port_;
 | 
				
			||||
 | 
				
			||||
  std::map<evutil_socket_t, const TcpConnection *> connections_;
 | 
				
			||||
  OnConnected client_connection_;
 | 
				
			||||
  OnDisconnected client_disconnection_;
 | 
				
			||||
  OnAccepted client_accept_;
 | 
				
			||||
  std::recursive_mutex connection_mutex_;
 | 
				
			||||
  OnServerReceiveMessage message_callback_;
 | 
				
			||||
};
 | 
				
			||||
}  // namespace comm
 | 
				
			||||
}  // namespace ps
 | 
				
			||||
}  // namespace mindspore
 | 
				
			||||
#endif  // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
 | 
				
			||||
@ -0,0 +1,46 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#include "common/common_test.h"
 | 
				
			||||
#include "ps/comm/tcp_client.h"
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace ps {
 | 
				
			||||
namespace comm {
 | 
				
			||||
class TestTcpClient : public UT::Common {
 | 
				
			||||
 public:
 | 
				
			||||
  TestTcpClient() = default;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
TEST_F(TestTcpClient, InitClientIPError) {
 | 
				
			||||
  auto client = new TcpClient("127.0.0.13543", 9000);
 | 
				
			||||
  client->ReceiveMessage(
 | 
				
			||||
    [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
 | 
				
			||||
 | 
				
			||||
  ASSERT_THROW(client->InitTcpClient(), std::exception);
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
TEST_F(TestTcpClient, InitClientPortErrorNoException) {
 | 
				
			||||
  auto client = new TcpClient("127.0.0.1", -1);
 | 
				
			||||
  client->ReceiveMessage(
 | 
				
			||||
    [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
 | 
				
			||||
 | 
				
			||||
  EXPECT_NO_THROW(client->InitTcpClient());
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
}  // namespace comm
 | 
				
			||||
}  // namespace ps
 | 
				
			||||
}  // namespace mindspore
 | 
				
			||||
@ -0,0 +1,71 @@
 | 
				
			||||
/**
 | 
				
			||||
 * Copyright 2020 Huawei Technologies Co., Ltd
 | 
				
			||||
 *
 | 
				
			||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
 * you may not use this file except in compliance with the License.
 | 
				
			||||
 * You may obtain a copy of the License at
 | 
				
			||||
 *
 | 
				
			||||
 * http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 *
 | 
				
			||||
 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
 * See the License for the specific language governing permissions and
 | 
				
			||||
 * limitations under the License.
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#include "ps/comm/tcp_client.h"
 | 
				
			||||
#include "ps/comm/tcp_server.h"
 | 
				
			||||
#include "common/common_test.h"
 | 
				
			||||
 | 
				
			||||
#include <thread>
 | 
				
			||||
 | 
				
			||||
namespace mindspore {
 | 
				
			||||
namespace ps {
 | 
				
			||||
namespace comm {
 | 
				
			||||
class TestTcpServer : public UT::Common {
 | 
				
			||||
 public:
 | 
				
			||||
  TestTcpServer() = default;
 | 
				
			||||
  void SetUp() override {
 | 
				
			||||
    server_ = new TcpServer("127.0.0.1", 9000);
 | 
				
			||||
    std::unique_ptr<std::thread> http_server_thread_(nullptr);
 | 
				
			||||
    http_server_thread_ = std::make_unique<std::thread>([&]() {
 | 
				
			||||
      server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) {
 | 
				
			||||
        EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
 | 
				
			||||
        server.SendMessage(conn, buffer, num);
 | 
				
			||||
      });
 | 
				
			||||
      server_->InitServer();
 | 
				
			||||
      server_->Start();
 | 
				
			||||
    });
 | 
				
			||||
    http_server_thread_->detach();
 | 
				
			||||
    std::this_thread::sleep_for(std::chrono::milliseconds(2000));
 | 
				
			||||
  }
 | 
				
			||||
  void TearDown() override {
 | 
				
			||||
    std::this_thread::sleep_for(std::chrono::milliseconds(2000));
 | 
				
			||||
    client_->Stop();
 | 
				
			||||
    std::this_thread::sleep_for(std::chrono::milliseconds(2000));
 | 
				
			||||
    server_->Stop();
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  TcpClient *client_;
 | 
				
			||||
  TcpServer *server_;
 | 
				
			||||
  const std::string test_message_ = "TCP_MESSAGE";
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
TEST_F(TestTcpServer, ServerSendMessage) {
 | 
				
			||||
  client_ = new TcpClient("127.0.0.1", 9000);
 | 
				
			||||
  std::unique_ptr<std::thread> http_client_thread(nullptr);
 | 
				
			||||
  http_client_thread = std::make_unique<std::thread>([&]() {
 | 
				
			||||
    client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) {
 | 
				
			||||
      EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
 | 
				
			||||
    });
 | 
				
			||||
 | 
				
			||||
    client_->InitTcpClient();
 | 
				
			||||
    client_->SendMessage(test_message_.c_str(), test_message_.size());
 | 
				
			||||
    client_->Start();
 | 
				
			||||
  });
 | 
				
			||||
  http_client_thread->detach();
 | 
				
			||||
}
 | 
				
			||||
}  // namespace comm
 | 
				
			||||
}  // namespace ps
 | 
				
			||||
}  // namespace mindspore
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue