From 40f2571f7e82b854f6ce877e7e4e99ccf9b7422c Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Wed, 3 Feb 2021 18:07:14 +0800 Subject: [PATCH] added http client --- mindspore/ccsrc/ps/CMakeLists.txt | 1 + mindspore/ccsrc/ps/core/comm_util.h | 6 +- mindspore/ccsrc/ps/core/http_client.cc | 226 ++++++++++++++++++ mindspore/ccsrc/ps/core/http_client.h | 97 ++++++++ .../ccsrc/ps/core/http_message_handler.cc | 73 +++++- .../ccsrc/ps/core/http_message_handler.h | 39 ++- mindspore/ccsrc/ps/core/http_server.cc | 9 +- mindspore/ccsrc/ps/core/http_server.h | 19 +- mindspore/ccsrc/ps/core/tcp_client.cc | 2 +- mindspore/ccsrc/ps/core/tcp_server.cc | 2 +- tests/ut/cpp/ps/core/http_client_test.cc | 121 ++++++++++ tests/ut/cpp/ps/core/http_server_test.cc | 2 - 12 files changed, 564 insertions(+), 33 deletions(-) create mode 100644 mindspore/ccsrc/ps/core/http_client.cc create mode 100644 mindspore/ccsrc/ps/core/http_client.h create mode 100644 tests/ut/cpp/ps/core/http_client_test.cc diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index 0c9c04874d..29f1e8cee2 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -20,6 +20,7 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc") endif() if(NOT ENABLE_D) diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index 41fb373741..8ba8efd3a2 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -61,6 +61,11 @@ constexpr int kGroup3RandomLength = 4; constexpr int kGroup4RandomLength = 4; constexpr int kGroup5RandomLength = 12; +// The size of the buffer for sending and receiving data is 4096 bytes. +constexpr int kMessageChunkLength = 4096; +// The timeout period for the http client to connect to the http server is 120 seconds. +constexpr int kConnectionTimeout = 120; + class CommUtil { public: static bool CheckIpWithRegex(const std::string &ip); @@ -80,5 +85,4 @@ class CommUtil { } // namespace core } // namespace ps } // namespace mindspore - #endif // MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ diff --git a/mindspore/ccsrc/ps/core/http_client.cc b/mindspore/ccsrc/ps/core/http_client.cc new file mode 100644 index 0000000000..a66a25b6c1 --- /dev/null +++ b/mindspore/ccsrc/ps/core/http_client.cc @@ -0,0 +1,226 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/core/http_client.h" + +namespace mindspore { +namespace ps { +namespace core { +HttpClient::~HttpClient() { + if (event_base_ != nullptr) { + event_base_free(event_base_); + event_base_ = nullptr; + } +} + +void HttpClient::Init() { + event_base_ = event_base_new(); + MS_EXCEPTION_IF_NULL(event_base_); + dns_base_ = evdns_base_new(event_base_, 1); + MS_EXCEPTION_IF_NULL(dns_base_); +} + +Status HttpClient::Post(const std::string &url, const void *body, size_t len, std::shared_ptr> output, + const std::map &headers) { + MS_EXCEPTION_IF_NULL(body); + MS_EXCEPTION_IF_NULL(output); + auto handler = std::make_shared(); + output->clear(); + handler->set_body(output); + + struct evhttp_request *request = evhttp_request_new(ReadCallback, reinterpret_cast(handler.get())); + MS_EXCEPTION_IF_NULL(request); + + InitRequest(handler, url, request); + + struct evhttp_connection *connection = + evhttp_connection_base_new(event_base_, dns_base_, handler->GetHostByUri(), handler->GetUriPort()); + if (!connection) { + MS_LOG(ERROR) << "Create http connection failed!"; + return Status::BADREQUEST; + } + + struct evbuffer *buffer = evhttp_request_get_output_buffer(request); + if (evbuffer_add(buffer, body, len) != 0) { + MS_LOG(ERROR) << "Add buffer failed!"; + return Status::INTERNAL; + } + + AddHeaders(headers, request, handler); + + return CreateRequest(handler, connection, request, HttpMethod::HM_POST); +} + +Status HttpClient::Get(const std::string &url, std::shared_ptr> output, + const std::map &headers) { + MS_EXCEPTION_IF_NULL(output); + auto handler = std::make_shared(); + output->clear(); + handler->set_body(output); + + struct evhttp_request *request = evhttp_request_new(ReadCallback, reinterpret_cast(handler.get())); + MS_EXCEPTION_IF_NULL(request); + + InitRequest(handler, url, request); + + struct evhttp_connection *connection = + evhttp_connection_base_new(event_base_, dns_base_, handler->GetHostByUri(), handler->GetUriPort()); + if (!connection) { + MS_LOG(ERROR) << "Create http connection failed!"; + return Status::BADREQUEST; + } + + AddHeaders(headers, request, handler); + + return CreateRequest(handler, connection, request, HttpMethod::HM_GET); +} + +void HttpClient::set_connection_timeout(const int &timeout) { connection_timout_ = timeout; } + +void HttpClient::ReadCallback(struct evhttp_request *request, void *arg) { + MS_EXCEPTION_IF_NULL(request); + MS_EXCEPTION_IF_NULL(arg); + auto handler = static_cast(arg); + if (event_base_loopexit(handler->http_base(), nullptr) != 0) { + MS_LOG(EXCEPTION) << "event base loop exit failed!"; + } +} + +int HttpClient::ReadHeaderDoneCallback(struct evhttp_request *request, void *arg) { + MS_EXCEPTION_IF_NULL(request); + MS_EXCEPTION_IF_NULL(arg); + auto handler = static_cast(arg); + handler->set_request(request); + MS_LOG(DEBUG) << "The http response code is:" << evhttp_request_get_response_code(request) + << ", The request code line is:" << evhttp_request_get_response_code_line(request); + struct evkeyvalq *headers = evhttp_request_get_input_headers(request); + struct evkeyval *header; + TAILQ_FOREACH(header, headers, next) { + MS_LOG(DEBUG) << "The key:" << header->key << ",The value:" << header->value; + std::string len = "Content-Length"; + if (!strcmp(header->key, len.c_str())) { + handler->set_content_len(strtouq(header->value, nullptr, 10)); + handler->InitBodySize(); + } + } + return 0; +} + +void HttpClient::ReadChunkDataCallback(struct evhttp_request *request, void *arg) { + MS_EXCEPTION_IF_NULL(request); + MS_EXCEPTION_IF_NULL(arg); + auto handler = static_cast(arg); + char buf[kMessageChunkLength]; + struct evbuffer *evbuf = evhttp_request_get_input_buffer(request); + MS_EXCEPTION_IF_NULL(evbuf); + int n = 0; + while ((n = evbuffer_remove(evbuf, &buf, sizeof(buf))) > 0) { + handler->ReceiveMessage(buf, n); + } +} + +void HttpClient::RequestErrorCallback(enum evhttp_request_error error, void *arg) { + MS_EXCEPTION_IF_NULL(arg); + auto handler = static_cast(arg); + MS_LOG(ERROR) << "The request failed, the error is:" << error; + if (event_base_loopexit(handler->http_base(), nullptr) != 0) { + MS_LOG(EXCEPTION) << "event base loop exit failed!"; + } +} + +void HttpClient::ConnectionCloseCallback(struct evhttp_connection *connection, void *arg) { + MS_EXCEPTION_IF_NULL(connection); + MS_EXCEPTION_IF_NULL(arg); + MS_LOG(ERROR) << "Remote connection closed!"; + if (event_base_loopexit((struct event_base *)arg, nullptr) != 0) { + MS_LOG(EXCEPTION) << "event base loop exit failed!"; + } +} + +void HttpClient::AddHeaders(const std::map &headers, struct evhttp_request *request, + std::shared_ptr handler) { + MS_EXCEPTION_IF_NULL(request); + if (evhttp_add_header(evhttp_request_get_output_headers(request), "Host", handler->GetHostByUri()) != 0) { + MS_LOG(EXCEPTION) << "Add header failed!"; + } + for (auto &header : headers) { + if (evhttp_add_header(evhttp_request_get_output_headers(request), header.first.data(), header.second.data()) != 0) { + MS_LOG(EXCEPTION) << "Add header failed!"; + } + } +} + +void HttpClient::InitRequest(std::shared_ptr handler, const std::string &url, + struct evhttp_request *request) { + MS_EXCEPTION_IF_NULL(request); + MS_EXCEPTION_IF_NULL(handler); + handler->set_http_base(event_base_); + handler->ParseUrl(url); + evhttp_request_set_header_cb(request, ReadHeaderDoneCallback); + evhttp_request_set_chunked_cb(request, ReadChunkDataCallback); + evhttp_request_set_error_cb(request, RequestErrorCallback); + + MS_LOG(DEBUG) << "The url is:" << url << ", The host is:" << handler->GetHostByUri() + << ", The port is:" << handler->GetUriPort() << ", The request_url is:" << handler->GetRequestPath(); +} + +Status HttpClient::CreateRequest(std::shared_ptr handler, struct evhttp_connection *connection, + struct evhttp_request *request, HttpMethod method) { + MS_EXCEPTION_IF_NULL(handler); + MS_EXCEPTION_IF_NULL(connection); + MS_EXCEPTION_IF_NULL(request); + evhttp_connection_set_closecb(connection, ConnectionCloseCallback, event_base_); + evhttp_connection_set_timeout(connection, connection_timout_); + + if (evhttp_make_request(connection, request, evhttp_cmd_type(method), handler->GetRequestPath().c_str()) != 0) { + MS_LOG(ERROR) << "Make request failed!"; + return Status::INTERNAL; + } + + if (!Start()) { + MS_LOG(ERROR) << "Start http client failed!"; + return Status::INTERNAL; + } + + if (handler->request()) { + MS_LOG(DEBUG) << "The http response code is:" << evhttp_request_get_response_code(handler->request()) + << ", The request code line is:" << evhttp_request_get_response_code_line(handler->request()); + return Status(evhttp_request_get_response_code(handler->request())); + } + return Status::INTERNAL; +} + +bool HttpClient::Start() { + MS_EXCEPTION_IF_NULL(event_base_); + // int ret = event_base_dispatch(event_base_); + int ret = event_base_loop(event_base_, 0); + if (ret == 0) { + MS_LOG(DEBUG) << "Event base dispatch success!"; + return true; + } else if (ret == 1) { + MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; + return false; + } else if (ret == -1) { + MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; + return false; + } else { + MS_LOG(EXCEPTION) << "Event base dispatch with unexpected error code!"; + } + return true; +} +} // namespace core +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/http_client.h b/mindspore/ccsrc/ps/core/http_client.h new file mode 100644 index 0000000000..efef6d8ac2 --- /dev/null +++ b/mindspore/ccsrc/ps/core/http_client.h @@ -0,0 +1,97 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_CLIENT_H_ +#define MINDSPORE_CCSRC_PS_CORE_HTTP_CLIENT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ps/core/http_message_handler.h" +#include "ps/core/comm_util.h" + +namespace mindspore { +namespace ps { +namespace core { + +enum class HttpMethod { HM_GET = 1 << 0, HM_POST = 1 << 1 }; + +enum class Status : int { + OK = 200, // request completed ok + BADREQUEST = 400, // invalid http request was made + NOTFOUND = 404, // could not find content for uri + INTERNAL = 500 // internal error +}; + +class HttpClient { + public: + HttpClient() : event_base_(nullptr), dns_base_(nullptr), is_init_(false), connection_timout_(kConnectionTimeout) { + Init(); + } + + virtual ~HttpClient(); + + Status Post(const std::string &url, const void *body, size_t len, std::shared_ptr> output, + const std::map &headers = {}); + Status Get(const std::string &url, std::shared_ptr> output, + const std::map &headers = {}); + + void set_connection_timeout(const int &timeout); + + private: + static void ReadCallback(struct evhttp_request *remote_rsp, void *arg); + static int ReadHeaderDoneCallback(struct evhttp_request *remote_rsp, void *arg); + static void ReadChunkDataCallback(struct evhttp_request *remote_rsp, void *arg); + static void RequestErrorCallback(enum evhttp_request_error error, void *arg); + static void ConnectionCloseCallback(struct evhttp_connection *connection, void *arg); + + void AddHeaders(const std::map &headers, struct evhttp_request *request, + std::shared_ptr handler); + void InitRequest(std::shared_ptr handler, const std::string &url, struct evhttp_request *request); + Status CreateRequest(std::shared_ptr handler, struct evhttp_connection *connection, + struct evhttp_request *request, HttpMethod method); + + bool Start(); + void Init(); + + struct event_base *event_base_; + struct evdns_base *dns_base_; + bool is_init_; + int connection_timout_; +}; +} // namespace core +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_CORE_HTTP_CLIENT_H_ diff --git a/mindspore/ccsrc/ps/core/http_message_handler.cc b/mindspore/ccsrc/ps/core/http_message_handler.cc index a3b3351c91..aa5a7fe2d8 100644 --- a/mindspore/ccsrc/ps/core/http_message_handler.cc +++ b/mindspore/ccsrc/ps/core/http_message_handler.cc @@ -44,6 +44,7 @@ void HttpMessageHandler::InitHttpMessage() { const char *query = evhttp_uri_get_query(event_uri_); if (query) { + MS_LOG(WARNING) << "The query is:" << query; evhttp_parse_query_str(query, &path_params_); } @@ -52,6 +53,11 @@ void HttpMessageHandler::InitHttpMessage() { resp_buf_ = evhttp_request_get_output_buffer(event_request_); } +void HttpMessageHandler::ParseUrl(const std::string &url) { + event_uri_ = evhttp_uri_parse(url.c_str()); + MS_EXCEPTION_IF_NULL(event_uri_); +} + std::string HttpMessageHandler::GetHeadParam(const std::string &key) { MS_EXCEPTION_IF_NULL(head_params_); const char *val = evhttp_find_header(head_params_, key.c_str()); @@ -74,8 +80,8 @@ void HttpMessageHandler::ParsePostParam() { post_param_parsed_ = true; const char *post_message = reinterpret_cast(evbuffer_pullup(event_request_->input_buffer, -1)); MS_EXCEPTION_IF_NULL(post_message); - body_ = std::make_unique(post_message, len); - int ret = evhttp_parse_query_str(body_->c_str(), &post_params_); + post_message_ = std::make_unique(post_message, len); + int ret = evhttp_parse_query_str(post_message_->c_str(), &post_params_); if (ret == -1) { MS_LOG(EXCEPTION) << "Parse post parameter failed!"; } @@ -105,9 +111,20 @@ std::string HttpMessageHandler::GetRequestHost() { return std::string(host); } +const char *HttpMessageHandler::GetHostByUri() { + MS_EXCEPTION_IF_NULL(event_uri_); + const char *host = evhttp_uri_get_host(event_uri_); + MS_EXCEPTION_IF_NULL(host); + return host; +} + int HttpMessageHandler::GetUriPort() { MS_EXCEPTION_IF_NULL(event_uri_); - return evhttp_uri_get_port(event_uri_); + int port = evhttp_uri_get_port(event_uri_); + if (port < 0) { + MS_LOG(EXCEPTION) << "The port:" << port << " should not be less than 0!"; + } + return port; } std::string HttpMessageHandler::GetUriPath() { @@ -117,6 +134,21 @@ std::string HttpMessageHandler::GetUriPath() { return std::string(path); } +std::string HttpMessageHandler::GetRequestPath() { + MS_EXCEPTION_IF_NULL(event_uri_); + const char *path = evhttp_uri_get_path(event_uri_); + if (path == nullptr || strlen(path) == 0) { + path = "/"; + } + std::string path_res(path); + const char *query = evhttp_uri_get_query(event_uri_); + if (query) { + path_res.append("?"); + path_res.append(query); + } + return path_res; +} + std::string HttpMessageHandler::GetUriQuery() { MS_EXCEPTION_IF_NULL(event_uri_); const char *query = evhttp_uri_get_query(event_uri_); @@ -202,6 +234,41 @@ void HttpMessageHandler::RespError(int nCode, const std::string &message) { evhttp_send_error(event_request_, nCode, message.c_str()); } } + +void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { + MS_EXCEPTION_IF_NULL(buffer); + int ret = memcpy_s(body_->data() + offset_, num, buffer, num); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + offset_ += num; +} + +void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; } + +uint64_t HttpMessageHandler::content_len() { return content_len_; } + +event_base *HttpMessageHandler::http_base() { return event_base_; } + +void HttpMessageHandler::set_http_base(const struct event_base *base) { + MS_EXCEPTION_IF_NULL(base); + event_base_ = const_cast(base); +} + +void HttpMessageHandler::set_request(const struct evhttp_request *req) { + MS_EXCEPTION_IF_NULL(req); + event_request_ = const_cast(req); +} + +struct evhttp_request *HttpMessageHandler::request() { + return event_request_; +} + +void HttpMessageHandler::InitBodySize() { body_->resize(content_len()); } + +std::shared_ptr> HttpMessageHandler::body() { return body_; } + +void HttpMessageHandler::set_body(std::shared_ptr> body) { body_ = body; } } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/http_message_handler.h b/mindspore/ccsrc/ps/core/http_message_handler.h index 3b7e571e7d..1c2f2a7da5 100644 --- a/mindspore/ccsrc/ps/core/http_message_handler.h +++ b/mindspore/ccsrc/ps/core/http_message_handler.h @@ -32,37 +32,49 @@ #include #include #include +#include + +#include "ps/core/comm_util.h" #include "utils/log_adapter.h" namespace mindspore { namespace ps { namespace core { using HttpHeaders = std::map>; +using VectorPtr = std::shared_ptr>; class HttpMessageHandler { public: - explicit HttpMessageHandler(struct evhttp_request *req) - : event_request_(req), + HttpMessageHandler() + : event_request_(nullptr), event_uri_(nullptr), path_params_{0}, head_params_(nullptr), post_params_{0}, post_param_parsed_(false), + post_message_(nullptr), body_(nullptr), resp_headers_(nullptr), resp_buf_(nullptr), - resp_code_(HTTP_OK) {} + resp_code_(HTTP_OK), + content_len_(0), + event_base_(nullptr), + offset_(0) {} virtual ~HttpMessageHandler() = default; void InitHttpMessage(); + void ParseUrl(const std::string &url); + std::string GetRequestUri(); std::string GetRequestHost(); + const char *GetHostByUri(); std::string GetHeadParam(const std::string &key); std::string GetPathParam(const std::string &key); std::string GetPostParam(const std::string &key); uint64_t GetPostMsg(unsigned char **buffer); std::string GetUriPath(); + std::string GetRequestPath(); std::string GetUriQuery(); // It will return -1 if no port set @@ -83,6 +95,18 @@ class HttpMessageHandler { // If message is empty, libevent will use default error code message instead void RespError(int nCode, const std::string &message); + // Body length should no more than MAX_POST_BODY_LEN, default 64kB + void ParsePostParam(); + void ReceiveMessage(const void *buffer, size_t num); + void set_content_len(const uint64_t &len); + uint64_t content_len(); + event_base *http_base(); + void set_http_base(const struct event_base *base); + void set_request(const struct evhttp_request *req); + struct evhttp_request *request(); + void InitBodySize(); + VectorPtr body(); + void set_body(VectorPtr body); private: struct evhttp_request *event_request_; @@ -91,13 +115,14 @@ class HttpMessageHandler { struct evkeyvalq *head_params_; struct evkeyvalq post_params_; bool post_param_parsed_; - std::unique_ptr body_; + std::unique_ptr post_message_; + VectorPtr body_; struct evkeyvalq *resp_headers_; struct evbuffer *resp_buf_; int resp_code_; - - // Body length should no more than MAX_POST_BODY_LEN, default 64kB - void ParsePostParam(); + uint64_t content_len_; + struct event_base *event_base_; + uint64_t offset_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/http_server.cc b/mindspore/ccsrc/ps/core/http_server.cc index 5ed0376f8f..36a3862175 100644 --- a/mindspore/ccsrc/ps/core/http_server.cc +++ b/mindspore/ccsrc/ps/core/http_server.cc @@ -57,6 +57,7 @@ bool HttpServer::InitServer() { MS_EXCEPTION_IF_NULL(event_base_); event_http_ = evhttp_new(event_base_); MS_EXCEPTION_IF_NULL(event_http_); + evhttp_set_timeout(event_http_, request_timeout_); int ret = evhttp_bind_socket(event_http_, server_address_.c_str(), server_port_); if (ret != 0) { MS_LOG(EXCEPTION) << "Http bind server addr:" << server_address_.c_str() << " port:" << server_port_ << "failed"; @@ -70,7 +71,7 @@ void HttpServer::SetTimeOut(int seconds) { if (seconds < 0) { MS_LOG(EXCEPTION) << "The timeout seconds:" << seconds << "is less than 0!"; } - evhttp_set_timeout(event_http_, seconds); + request_timeout_ = seconds; } void HttpServer::SetAllowedMethod(u_int16_t methods) { @@ -105,7 +106,8 @@ bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *functio auto TransFunc = [](struct evhttp_request *req, void *arg) { MS_EXCEPTION_IF_NULL(req); MS_EXCEPTION_IF_NULL(arg); - auto httpReq = std::make_shared(req); + auto httpReq = std::make_shared(); + httpReq->set_request(req); httpReq->InitHttpMessage(); OnRequestReceive *func = reinterpret_cast(arg); (*func)(httpReq); @@ -144,8 +146,9 @@ bool HttpServer::Start() { MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; return false; } else { - MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; + MS_LOG(EXCEPTION) << "Event base dispatch with unexpected error code!"; } + return true; } void HttpServer::Stop() { diff --git a/mindspore/ccsrc/ps/core/http_server.h b/mindspore/ccsrc/ps/core/http_server.h index 9d95f33faf..fd05d6a229 100644 --- a/mindspore/ccsrc/ps/core/http_server.h +++ b/mindspore/ccsrc/ps/core/http_server.h @@ -38,18 +38,6 @@ namespace mindspore { namespace ps { namespace core { -typedef enum eHttpMethod { - HM_GET = 1 << 0, - HM_POST = 1 << 1, - HM_HEAD = 1 << 2, - HM_PUT = 1 << 3, - HM_DELETE = 1 << 4, - HM_OPTIONS = 1 << 5, - HM_TRACE = 1 << 6, - HM_CONNECT = 1 << 7, - HM_PATCH = 1 << 8 -} HttpMethod; - using OnRequestReceive = std::function)>; class HttpServer { @@ -61,12 +49,13 @@ class HttpServer { event_base_(nullptr), event_http_(nullptr), is_init_(false), - is_stop_(true) {} + is_stop_(true), + request_timeout_(300) {} ~HttpServer(); bool InitServer(); - void SetTimeOut(int seconds = 5); + void SetTimeOut(int seconds); // Default allowed methods: GET, POST, HEAD, PUT, DELETE void SetAllowedMethod(u_int16_t methods); @@ -91,9 +80,9 @@ class HttpServer { struct evhttp *event_http_; bool is_init_; std::atomic is_stop_; + int request_timeout_; }; } // namespace core } // namespace ps } // namespace mindspore - #endif // MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ diff --git a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc index a030fc9de5..d0d29121dd 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -176,7 +176,7 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { struct evbuffer *input = bufferevent_get_input(const_cast(bev)); MS_EXCEPTION_IF_NULL(input); - char read_buffer[4096]; + char read_buffer[kMessageChunkLength]; while (EVBUFFER_LENGTH(input) > 0) { int read = evbuffer_remove(input, &read_buffer, sizeof(read_buffer)); diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc index 143fa8395b..e802ade90c 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -330,7 +330,7 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { auto conn = static_cast(connection); struct evbuffer *buf = bufferevent_get_input(bev); - char read_buffer[4096]; + char read_buffer[kMessageChunkLength]; while (EVBUFFER_LENGTH(buf) > 0) { int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)); if (read == -1) { diff --git a/tests/ut/cpp/ps/core/http_client_test.cc b/tests/ut/cpp/ps/core/http_client_test.cc new file mode 100644 index 0000000000..ed50abe2f6 --- /dev/null +++ b/tests/ut/cpp/ps/core/http_client_test.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/common_test.h" +#include "ps/core/http_server.h" +#include "ps/core/http_client.h" + +using namespace std; + +namespace mindspore { +namespace ps { +namespace core { +class TestHttpClient : public UT::Common { + public: + TestHttpClient() : server_(nullptr), http_server_thread_(nullptr) {} + + virtual ~TestHttpClient() = default; + + OnRequestReceive http_get_func = std::bind( + [](std::shared_ptr resp) { + EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget"); + const unsigned char ret[] = "get request success!\n"; + resp->QuickResponse(200, ret, 22); + }, + std::placeholders::_1); + + OnRequestReceive http_handler_func = std::bind( + [](std::shared_ptr resp) { + std::string host = resp->GetRequestHost(); + EXPECT_STREQ(host.c_str(), "127.0.0.1"); + + std::string path_param = resp->GetPathParam("key1"); + std::string header_param = resp->GetHeadParam("headerKey"); + unsigned char *data = nullptr; + const uint64_t len = resp->GetPostMsg(&data); + char post_message[len + 1]; + if (memset_s(post_message, len + 1, 0, len + 1) != 0) { + MS_LOG(EXCEPTION) << "The memset_s error"; + } + if (memcpy_s(post_message, len, data, len) != 0) { + MS_LOG(EXCEPTION) << "The memset_s error"; + } + EXPECT_STREQ(path_param.c_str(), "value1"); + EXPECT_STREQ(header_param.c_str(), "headerValue"); + EXPECT_STREQ(post_message, "postKey=postValue"); + + const std::string rKey("headKey"); + const std::string rVal("headValue"); + const std::string rBody("post request success!\n"); + resp->AddRespHeadParam(rKey, rVal); + resp->AddRespString(rBody); + + resp->SetRespCode(200); + resp->SendResponse(); + }, + std::placeholders::_1); + + void SetUp() override { + server_ = std::make_unique("0.0.0.0", 9999); + + server_->RegisterRoute("/httpget", &http_get_func); + server_->RegisterRoute("/handler", &http_handler_func); + http_server_thread_ = std::make_unique([&]() { server_->Start(); }); + http_server_thread_->detach(); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + + void TearDown() override { + server_->Stop(); + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + } + + private: + std::unique_ptr server_; + std::unique_ptr http_server_thread_; +}; + +TEST_F(TestHttpClient, Get) { + HttpClient client; + std::map headers = {{"headerKey", "headerValue"}}; + auto output = std::make_shared>(); + auto ret = client.Get("http://127.0.0.1:9999/httpget", output, headers); + EXPECT_STREQ("get request success!\n", output->data()); + EXPECT_EQ(Status::OK, ret); +} + +TEST_F(TestHttpClient, Post) { + HttpClient client; + std::map headers = {{"headerKey", "headerValue"}}; + auto output = std::make_shared>(); + std::string post_data = "postKey=postValue"; + auto ret = + client.Post("http://127.0.0.1:9999/handler?key1=value1", post_data.c_str(), post_data.length(), output, headers); + EXPECT_STREQ("post request success!\n", output->data()); + EXPECT_EQ(Status::OK, ret); +} +} // namespace core +} // namespace ps +} // namespace mindspore diff --git a/tests/ut/cpp/ps/core/http_server_test.cc b/tests/ut/cpp/ps/core/http_server_test.cc index 01445eabe4..e7e1e1fa63 100644 --- a/tests/ut/cpp/ps/core/http_server_test.cc +++ b/tests/ut/cpp/ps/core/http_server_test.cc @@ -42,7 +42,6 @@ class TestHttpServer : public UT::Common { std::string path_param = resp->GetPathParam("key1"); std::string header_param = resp->GetHeadParam("headerKey"); - std::string post_param = resp->GetPostParam("postKey"); unsigned char *data = nullptr; const uint64_t len = resp->GetPostMsg(&data); char post_message[len + 1]; @@ -54,7 +53,6 @@ class TestHttpServer : public UT::Common { } EXPECT_STREQ(path_param.c_str(), "value1"); EXPECT_STREQ(header_param.c_str(), "headerValue"); - EXPECT_STREQ(post_param.c_str(), "postValue"); EXPECT_STREQ(post_message, "postKey=postValue"); const std::string rKey("headKey");