From ee4132889e81dabd5fc42c7c990012847ca7751d Mon Sep 17 00:00:00 2001
From: chendongsheng
Date: Tue, 8 Dec 2020 21:30:32 +0800
Subject: [PATCH] added worker node

---
 mindspore/ccsrc/ps/CMakeLists.txt             |   1 +
 mindspore/ccsrc/ps/core/comm_util.cc          |  15 +-
 mindspore/ccsrc/ps/core/comm_util.h           |   2 +
 mindspore/ccsrc/ps/core/node.cc               | 239 ++++++++++++++++--
 mindspore/ccsrc/ps/core/node.h                |  58 ++++-
 mindspore/ccsrc/ps/core/protos/comm.proto     |   4 +
 mindspore/ccsrc/ps/core/tcp_client.cc         |   2 +-
 .../ccsrc/ps/core/tcp_message_handler.cc      |  10 +-
 mindspore/ccsrc/ps/core/tcp_message_handler.h |   7 +-
 mindspore/ccsrc/ps/core/tcp_server.cc         |   2 +-
 mindspore/ccsrc/ps/core/worker_node.cc        | 187 ++++++++++++++
 mindspore/ccsrc/ps/core/worker_node.h         |  69 +++++
 tests/ut/cpp/ps/core/common_util_test.cc      |   8 +
 .../cpp/ps/core/tcp_message_handler_test.cc   |  72 +++---
 14 files changed, 590 insertions(+), 86 deletions(-)
 create mode 100644 mindspore/ccsrc/ps/core/worker_node.cc
 create mode 100644 mindspore/ccsrc/ps/core/worker_node.h

diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt
index 2700a89d63..318e99635d 100644
--- a/mindspore/ccsrc/ps/CMakeLists.txt
+++ b/mindspore/ccsrc/ps/CMakeLists.txt
@@ -15,6 +15,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
     list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc")
 endif ()

 if (NOT ENABLE_D)
diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc
index 5fc35df074..2e1d73cecf 100644
--- a/mindspore/ccsrc/ps/core/comm_util.cc
+++ b/mindspore/ccsrc/ps/core/comm_util.cc
@@ -94,16 +94,16 @@ std::string CommUtil::GenerateUUID() {
     ss << dis(gen);
   }
   ss << "-4";
-  for (i = 0; i < kGroup2RandomLength - 1; i++) {
+  for (i = 0; i < kGroup3RandomLength - 1; i++) {
     ss << dis(gen);
   }
   ss << "-";
   ss << dis2(gen);
-  for (i = 0; i < kGroup3RandomLength - 1; i++) {
+  for (i = 0; i < kGroup4RandomLength - 1; i++) {
     ss << dis(gen);
   }
   ss << "-";
-  for (i = 0; i < kGroup4RandomLength; i++) {
+  for (i = 0; i < kGroup5RandomLength; i++) {
     ss << dis(gen);
   }
   return ss.str();
@@ -121,7 +121,14 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
       MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
   }
 }
-
+bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) {
+  if (node_role == NodeRole::SERVER && (rank_id > ClusterConfig::server_num() - 1)) {
+    return false;
+  } else if (node_role == NodeRole::WORKER && (rank_id > ClusterConfig::worker_num() - 1)) {
+    return false;
+  }
+  return true;
+}
 } // namespace core
 } // namespace ps
 } // namespace mindspore
diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h
index d48ef49891..13ed85db82 100644
--- a/mindspore/ccsrc/ps/core/comm_util.h
+++ b/mindspore/ccsrc/ps/core/comm_util.h
@@ -48,6 +48,7 @@
 #include "proto/comm.pb.h"
 #include "proto/ps.pb.h"
+#include "ps/core/cluster_config.h"
 #include "utils/log_adapter.h"

 namespace mindspore {
@@ -66,6 +67,7 @@ class CommUtil {
   static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
   static std::string GenerateUUID();
   static std::string NodeRoleToString(const NodeRole &role);
+  static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id);

  private:
   static std::random_device rd;
diff --git a/mindspore/ccsrc/ps/core/node.cc b/mindspore/ccsrc/ps/core/node.cc
index bbca86d302..500afb11d2 100644
--- a/mindspore/ccsrc/ps/core/node.cc
+++ b/mindspore/ccsrc/ps/core/node.cc
@@ -47,13 +47,17 @@ void Node::ProcessHeartbeatResp(const CommMessage &message) {
   is_ready_ = heartbeat_resp_message.is_cluster_ready();
   if (is_ready_.load()) {
     wait_start_cond_.notify_all();
+    MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
   }
   is_finish_ = heartbeat_resp_message.is_cluster_finish();
   if (is_finish_.load()) {
     wait_finish_cond_.notify_all();
+    MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
   }
   is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
   if (is_timeout_ && on_node_event_message_) {
+    is_ready_ = true;
+    wait_start_cond_.notify_all();
     on_node_event_message_(NodeEvent::NODE_TIMEOUT);
   }
 }
@@ -64,7 +68,9 @@ void Node::FetchServers(const std::shared_ptr &client) {
   CommMessage message;
   *message.mutable_pb_meta() = {meta};
-  SendMessageSync(client, message);
+  if (!SendMessageSync(client, message)) {
+    MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
+  }
 }

 void Node::ProcessFetchServersResp(const CommMessage &message) {
   fetch_servers_resp_message.ParseFromString(message.data());

   for (const auto &it : fetch_servers_resp_message.servers_meta()) {
-    server_rank_ids_[it.rank_id()] = std::make_pair(it.ip(), it.port());
+    nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
   }

-  MS_LOG(DEBUG) << "The all server host size is:" << server_rank_ids_.size();
+  MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
 }

 std::string Node::node_id() const { return node_info_.node_id_; }
@@ -86,19 +92,128 @@ void Node::set_callback(const OnNodeEventMessage &on_node_event_message) {
   on_node_event_message_ = on_node_event_message;
 }

-void Node::Wait(uint64_t request_id) {
-  std::unique_lock lock(message_mutex_);
-  message_tracker_cond_.wait(lock, [&] {
+bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
+  std::unique_lock lock(message_tracker_mutex_);
+  bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
     bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
-    if (ret) {
-      MS_LOG(DEBUG) << "Message tracker remove request id:" << request_id;
-      message_tracker_.erase(request_id);
-    }
     return ret;
   });
+  message_tracker_.erase(request_id);
+  return res;
+}
+
+bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+                const uint32_t &timeout) {
+  if (!CommUtil::ValidateRankId(node_role, rank_id)) {
+    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
+  }
+
+  MessageMeta message_meta;
+  message_meta.set_cmd(NodeCommand::SEND_DATA);
+
+  CommMessage comm_message;
+  *comm_message.mutable_pb_meta() = {message_meta};
+  comm_message.set_data(message);
+  auto client = GetOrCreateTcpClient(rank_id);
+  return SendMessageSync(client, comm_message);
 }

-void Node::Disconnect(const std::shared_ptr &client) {
+bool Node::Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data,
+                const uint32_t &timeout) {
+  uint64_t request_id = ++next_request_id_;
+  message_tracker_[request_id] = std::make_pair(data.size(), 0);
+
+  if (rank_ids.size() != data.size()) {
+    MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
+  }
+  for (size_t it = 0; it < rank_ids.size(); ++it) {
+    if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
+      MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
+    }
+
+    MessageMeta message_meta;
+    message_meta.set_cmd(NodeCommand::SEND_DATA);
+    message_meta.set_request_id(request_id);
+
+    CommMessage comm_message;
+    *comm_message.mutable_pb_meta() = {message_meta};
+    comm_message.set_data(data.at(it));
+
+    auto client = GetOrCreateTcpClient(rank_ids.at(it));
+    client->SendMessage(comm_message);
+  }
+  return Wait(request_id, timeout);
+}
+
+bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+                CommMessage *comm_message_resp, const uint32_t &timeout) {
+  if (!CommUtil::ValidateRankId(node_role, rank_id)) {
+    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
+  }
+
+  uint64_t request_id = ++next_request_id_;
+  message_tracker_[request_id] = std::make_pair(1, 0);
+  set_message_callback(request_id, [&]() {
+    receive_messages_mutex_.lock();
+    auto res = receive_messages_[request_id];
+    comm_message_resp = &res[rank_id];
+    receive_messages_.erase(request_id);
+    receive_messages_mutex_.unlock();
+  });
+
+  MessageMeta message_meta;
+  message_meta.set_cmd(NodeCommand::SEND_DATA);
+  message_meta.set_request_id(request_id);
+
+  CommMessage comm_message;
+  *comm_message.mutable_pb_meta() = {message_meta};
+  comm_message.set_data(message);
+  auto client = GetOrCreateTcpClient(rank_id);
+  client->SendMessage(comm_message);
+  return Wait(request_id, timeout);
+}
+
+bool Node::Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data,
+                std::vector *comm_message_resp, const uint32_t &timeout) {
+  uint64_t request_id = ++next_request_id_;
+  message_tracker_[request_id] = std::make_pair(data.size(), 0);
+
+  if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) {
+    MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
+  }
+
+  size_t len = rank_ids.size();
+
+  set_message_callback(request_id, [&]() {
+    receive_messages_mutex_.lock();
+    auto res = receive_messages_[request_id];
+    for (size_t it = 0; it < len; ++it) {
+      comm_message_resp->at(it) = &res[rank_ids.at(it)];
+    }
+    receive_messages_.erase(request_id);
+    receive_messages_mutex_.unlock();
+  });
+
+  for (size_t it = 0; it < len; ++it) {
+    if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
+      MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
+    }
+
+    MessageMeta message_meta;
+    message_meta.set_cmd(NodeCommand::SEND_DATA);
+    message_meta.set_request_id(request_id);
+
+    CommMessage comm_message;
+    *comm_message.mutable_pb_meta() = {message_meta};
+    comm_message.set_data(data.at(it));
+
+    auto client = GetOrCreateTcpClient(rank_ids.at(it));
+    client->SendMessage(comm_message);
+  }
+  return Wait(request_id, timeout);
+}
+
+bool Node::Disconnect(const std::shared_ptr &client, const uint32_t &timeout) {
   MessageMeta meta;
   meta.set_cmd(NodeCommand::FINISH);
@@ -108,36 +223,43 @@
   CommMessage message;
   *message.mutable_pb_meta() = {meta};
   message.set_data(finish_message.SerializeAsString());
-  SendMessageSync(client, message);
-  WaitForDisconnect();
+  if (!SendMessageSync(client, message)) {
+    MS_LOG(EXCEPTION) << "Disconnect timeout!";
+  }
+  MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
+  return WaitForDisconnect(timeout);
 }

-void Node::WaitForStart() {
+bool Node::WaitForStart(const uint32_t &timeout) {
   std::unique_lock lock(wait_start_mutex_);
-  wait_start_cond_.wait(lock, [&] {
-    if (is_ready_.load()) {
-      MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success start!";
+  bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
+    bool res = is_ready_.load();
+    if (res) {
+      MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
     }
-    return is_ready_.load();
+    return res;
   });
+  return res;
 }

-void Node::WaitForDisconnect() {
+bool Node::WaitForDisconnect(const uint32_t &timeout) {
   std::unique_lock lock(wait_finish_mutex_);
-  wait_finish_cond_.wait(lock, [&] {
+  bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
     if (is_finish_.load()) {
-      MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success finish!";
+      MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
     }
     return is_finish_.load();
   });
+  return res;
 }

-void Node::SendMessageSync(const std::shared_ptr &client, const CommMessage &message) {
+bool Node::SendMessageSync(const std::shared_ptr &client, const CommMessage &message,
+                           const uint32_t &timeout) {
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(1, 0);
   const_cast(message).mutable_pb_meta()->set_request_id(request_id);
   client->SendMessage(message);
-  Wait(request_id);
+  return Wait(request_id, timeout);
 }

 void Node::SendMessageAsync(const std::shared_ptr &client, const CommMessage &message) {
@@ -147,12 +269,83 @@
 }

 void Node::NotifyMessageArrival(const CommMessage &message) {
+  std::lock_guard lock(message_tracker_mutex_);
   const MessageMeta &message_meta = message.pb_meta();
   uint64_t request_id = message_meta.request_id();

   message_tracker_[request_id].second++;
   message_tracker_cond_.notify_all();
 }
+
+const std::shared_ptr &Node::GetOrCreateTcpClient(const int &rank_id) {
+  std::lock_guard lock(client_mutex_);
+  if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
+    return connected_nodes_[rank_id];
+  } else {
+    if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
+      MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!";
+    }
+    std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
+    uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
+    auto client = std::make_shared(ip, port);
+    client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
+      switch (message.pb_meta().cmd()) {
+        case NodeCommand::SEND_DATA:
+          ProcessSendDataResp(message);
+          break;
+        default:
+          MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+      }
+      NotifyMessageArrival(message);
+    });
+    client->Init();
+    connected_nodes_[rank_id] = client;
+    return connected_nodes_[rank_id];
+  }
+}
+
+void Node::ProcessSendDataResp(const CommMessage &message) {
+  std::lock_guard lock(receive_messages_mutex_);
+  const MessageMeta &message_meta = message.pb_meta();
+  const uint32_t &rank_id = message_meta.rank_id();
+  const uint64_t request_id = message_meta.request_id();
+  auto it = receive_messages_.find(request_id);
+  if (it != receive_messages_.end()) {
+    it->second.insert(std::make_pair(rank_id, message));
+  } else {
+    std::unordered_map res;
+    res.insert(std::make_pair(rank_id, message));
+    receive_messages_[request_id] = res;
+  }
+
+  RunMessageCallback(request_id);
+}
+
+void Node::RunMessageCallback(const uint64_t &request_id) {
+  message_callbacks_mutex_.lock();
+  if (message_tracker_[request_id].first == message_tracker_[request_id].second - 1) {
+    auto it = message_callbacks_.find(request_id);
+    if (it != message_callbacks_.end()) {
+      message_callbacks_mutex_.unlock();
+
+      if (it->second) {
+        it->second();
+      }
+
+      message_callbacks_mutex_.lock();
+      message_callbacks_.erase(it);
+    }
+  }
+  message_callbacks_mutex_.unlock();
+}
+
+void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) {
+  if (!message_callback) {
+    return;
+  }
+  std::lock_guard lock(message_callbacks_mutex_);
+  message_callbacks_[request_id] = message_callback;
+}
 } // namespace core
 } // namespace ps
 } // namespace mindspore
diff --git a/mindspore/ccsrc/ps/core/node.h b/mindspore/ccsrc/ps/core/node.h
index 5ff490f008..cecebbb229 100644
--- a/mindspore/ccsrc/ps/core/node.h
+++ b/mindspore/ccsrc/ps/core/node.h
@@ -21,15 +21,15 @@
 #include
 #include
 #include
-#include
 #include
-#include
 #include
 #include
 #include
 #include
 #include
 #include
+#include
+#include

 #include "proto/comm.pb.h"
 #include "proto/ps.pb.h"
@@ -42,6 +42,8 @@
 namespace mindspore {
 namespace ps {
 namespace core {
+constexpr int kTimeoutInSeconds = 30;
+constexpr int kCommTimeoutInSeconds = 3;
 class Node {
  public:
   Node()
@@ -49,51 +51,83 @@ class Node {
       is_finish_(false),
       is_timeout_(false),
       is_already_stopped_(true),
+      is_already_finished_(false),
       next_request_id_(0),
       heart_beat_thread_(nullptr) {}
   virtual ~Node() = default;

   using OnNodeEventMessage = std::function;
-  void set_callback(const OnNodeEventMessage &on_node_event_message);
+  using MessageCallback = std::function;

+  virtual bool Start(const uint32_t &timeout = kTimeoutInSeconds) = 0;
+  virtual bool Stop() = 0;
+  virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0;
+
+  void set_callback(const OnNodeEventMessage &on_node_event_message);
   std::string node_id() const;
   uint32_t rank_id() const;
+  bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);

-  void Wait(uint64_t request_id);
+  virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+                    const uint32_t &timeout = kCommTimeoutInSeconds);
+  virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids,
+                    const std::vector &data, const uint32_t &timeout = kCommTimeoutInSeconds);
+  virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+                    CommMessage *const comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
+  virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids,
+                    const std::vector &data, std::vector *comm_message_resp,
+                    const uint32_t &timeout = kCommTimeoutInSeconds);

  protected:
   void Heartbeat(const std::shared_ptr &client);
   void ProcessHeartbeatResp(const CommMessage &message);
   void FetchServers(const std::shared_ptr &client);
   void ProcessFetchServersResp(const CommMessage &message);
-  void Disconnect(const std::shared_ptr &client);
-  void WaitForStart();
-  void WaitForDisconnect();
-  void SendMessageSync(const std::shared_ptr &client, const CommMessage &message);
+  bool Disconnect(const std::shared_ptr &client, const uint32_t &timeout);
+  bool WaitForStart(const uint32_t &timeout);
+  bool WaitForDisconnect(const uint32_t &timeout);
+  bool SendMessageSync(const std::shared_ptr &client, const CommMessage &message,
+                       const uint32_t &timeout = kCommTimeoutInSeconds);
   void SendMessageAsync(const std::shared_ptr &client, const CommMessage &message);
   void NotifyMessageArrival(const CommMessage &message);
+  const std::shared_ptr &GetOrCreateTcpClient(const int &rank_id);
+  void ProcessSendDataResp(const CommMessage &message);
+  void RunMessageCallback(const uint64_t &request_id);
+  void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);

   NodeInfo node_info_;
   std::atomic is_ready_;
   std::atomic is_finish_;
   std::atomic is_timeout_;
   std::atomic is_already_stopped_;
+  std::atomic is_already_finished_;
   std::atomic_uint64_t next_request_id_;
   std::unique_ptr heart_beat_thread_;

   OnNodeEventMessage on_node_event_message_;

-  // rank_id->
-  std::unordered_map> server_rank_ids_;
+  // ->
+  std::map, std::pair> nodes_address_;
+  // rank_id->tcpclient
+  std::unordered_map> connected_nodes_;

-  // timestamp->
+  // request_id->
   std::unordered_map> message_tracker_;
-  std::mutex message_mutex_;
+  std::mutex message_tracker_mutex_;
   std::condition_variable message_tracker_cond_;
   std::mutex wait_finish_mutex_;
   std::condition_variable wait_finish_cond_;
   std::mutex wait_start_mutex_;
   std::condition_variable wait_start_cond_;
+  std::mutex finish_mutex_;
+  std::mutex client_mutex_;
+
+  // request_id ->
+  std::unordered_map> receive_messages_;
+  std::mutex receive_messages_mutex_;
+  // request_id -> MessageCallback
+  std::unordered_map message_callbacks_;
+  std::mutex message_callbacks_mutex_;
 };
 } // namespace core
 } // namespace ps
diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto
index f45fe583f5..4e47afeed0 100644
--- a/mindspore/ccsrc/ps/core/protos/comm.proto
+++ b/mindspore/ccsrc/ps/core/protos/comm.proto
@@ -39,6 +39,10 @@ message MessageMeta {
   NodeCommand cmd = 1;
   // the request id of this message
   uint64 request_id = 2;
+  // the role of the current node: worker,server,scheduler
+  NodeRole role = 3;
+  // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
+  int32 rank_id = 4;
 }

 message RegisterMessage {
diff --git a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc
index fcc0ec44b9..8868daa60e 100644
--- a/mindspore/ccsrc/ps/core/tcp_client.cc
+++ b/mindspore/ccsrc/ps/core/tcp_client.cc
@@ -249,7 +249,7 @@ void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb

 void TcpClient::SendMessage(const CommMessage &message) const {
   MS_EXCEPTION_IF_NULL(buffer_event_);
-  uint32_t buf_size = message.ByteSizeLong();
+  size_t buf_size = message.ByteSizeLong();
   std::vector serialized(buf_size);
   message.SerializeToArray(serialized.data(), static_cast(buf_size));
   if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {
diff --git a/mindspore/ccsrc/ps/core/tcp_message_handler.cc b/mindspore/ccsrc/ps/core/tcp_message_handler.cc
index a98b1352a2..c64b36a306 100644
--- a/mindspore/ccsrc/ps/core/tcp_message_handler.cc
+++ b/mindspore/ccsrc/ps/core/tcp_message_handler.cc
@@ -23,7 +23,6 @@
 namespace mindspore {
 namespace ps {
 namespace core {
-
 void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; }

 void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
@@ -32,11 +31,11 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {

   while (num > 0) {
     if (remaining_length_ == 0) {
-      for (int i = 0; i < 4 && num > 0; ++i) {
+      for (int i = 0; i < kHeaderLen && num > 0; ++i) {
         header_[++header_index_] = *(buffer_data + i);
         --num;
-        if (header_index_ == 3) {
-          message_length_ = *reinterpret_cast(header_);
+        if (header_index_ == kHeaderLen - 1) {
+          message_length_ = *reinterpret_cast(header_);
           remaining_length_ = message_length_;
           message_buffer_.reset(new unsigned char[remaining_length_]);
           buffer_data += (i + 1);
@@ -46,7 +45,7 @@
     }

     if (remaining_length_ > 0 && num > 0) {
-      uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
+      size_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
       remaining_length_ -= copy_len;
       num -= copy_len;

@@ -71,7 +70,6 @@
       }
     }
   }
-
 } // namespace core
 } // namespace ps
 } // namespace mindspore
diff --git a/mindspore/ccsrc/ps/core/tcp_message_handler.h b/mindspore/ccsrc/ps/core/tcp_message_handler.h
index b77041e65e..b728d8a3fc 100644
--- a/mindspore/ccsrc/ps/core/tcp_message_handler.h
+++ b/mindspore/ccsrc/ps/core/tcp_message_handler.h
@@ -31,6 +31,7 @@
 namespace mindspore {
 namespace ps {
 namespace core {
 using messageReceive = std::function;
+constexpr int kHeaderLen = 8;

 class TcpMessageHandler {
  public:
@@ -51,10 +52,10 @@ class TcpMessageHandler {
   bool is_parsed_;
   std::unique_ptr message_buffer_;
   size_t message_length_;
-  uint32_t remaining_length_;
-  char header_[4];
+  size_t remaining_length_;
+  char header_[8];
   int header_index_;
-  uint32_t last_copy_len_;
+  size_t last_copy_len_;
 };
 } // namespace core
 } // namespace ps
 } // namespace mindspore
diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc
index 1276ea63b0..cefc344f2e 100644
--- a/mindspore/ccsrc/ps/core/tcp_server.cc
+++ b/mindspore/ccsrc/ps/core/tcp_server.cc
@@ -55,7 +55,7 @@ const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }

 void TcpConnection::SendMessage(const CommMessage &message) const {
   MS_EXCEPTION_IF_NULL(buffer_event_);
-  uint32_t buf_size = message.ByteSizeLong();
+  size_t buf_size = message.ByteSizeLong();
   std::vector serialized(buf_size);
   message.SerializeToArray(serialized.data(), static_cast(buf_size));
   if (evbuffer_add(bufferevent_get_output(const_cast(buffer_event_)), &buf_size,
diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc
new file mode 100644
index 0000000000..eb38475748
--- /dev/null
+++ b/mindspore/ccsrc/ps/core/worker_node.cc
@@ -0,0 +1,187 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/core/worker_node.h"
+
+namespace mindspore {
+namespace ps {
+namespace core {
+WorkerNode::~WorkerNode() {
+  MS_LOG(INFO) << "Stop worker node!";
+  if (!is_already_stopped_.load()) {
+    is_ready_ = true;
+    is_timeout_ = true;
+    client_to_scheduler_->Stop();
+    if (!connected_nodes_.empty()) {
+      for (auto &connected_node : connected_nodes_) {
+        connected_node.second->Stop();
+      }
+    }
+    client_to_scheduler_->StopEventBase();
+    if (worker_thread_->joinable()) {
+      worker_thread_->join();
+    }
+    if (heart_beat_thread_->joinable()) {
+      heart_beat_thread_->join();
+    }
+    is_already_stopped_ = true;
+  }
+}
+bool WorkerNode::Start(const uint32_t &timeout) {
+  MS_LOG(INFO) << "Starting worker node!";
+  Initialize();
+  Register();
+  Heartbeat(client_to_scheduler_);
+
+  if (!WaitForStart(timeout)) {
+    MS_LOG(ERROR) << "Start Worker node timeout!";
+    return false;
+  }
+  MS_LOG(INFO) << "The node is ready to fetch servers!";
+
+  if (!is_timeout_.load()) {
+    FetchServers(client_to_scheduler_);
+    MS_LOG(INFO) << "Fetch servers successful!";
+  }
+  MS_LOG(INFO) << "The Worker node has successfully started.";
+  return true;
+}
+
+void WorkerNode::Register() {
+  MessageMeta message_meta;
+  message_meta.set_cmd(NodeCommand::REGISTER);
+
+  RegisterMessage register_message;
+  register_message.set_node_id(node_info_.node_id_);
+  register_message.set_role(node_info_.node_role_);
+
+  CommMessage comm_message;
+  *comm_message.mutable_pb_meta() = {message_meta};
+  comm_message.set_data(register_message.SerializeAsString());
+  if (!SendMessageSync(client_to_scheduler_, comm_message)) {
+    MS_LOG(EXCEPTION) << "Worker node register timeout!";
+  }
+  MS_LOG(INFO) << "The worker node id:" << node_info_.node_id_
+               << "is registering to scheduler, the request id is:" << message_meta.request_id();
+}
+
+void WorkerNode::ProcessRegisterResp(const CommMessage &message) {
+  RegisterRespMessage register_resp_message;
+  register_resp_message.ParseFromString(message.data());
+  if (register_resp_message.node_id() != node_info_.node_id_) {
+    MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
+                      << " is not match the current node id:" << node_info_.node_id_;
+  }
+
+  node_info_.rank_id_ = register_resp_message.rank_id();
+
+  MS_LOG(INFO) << "The client node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
+}
+
+void WorkerNode::Initialize() {
+  is_already_stopped_ = false;
+  node_info_.node_id_ = CommUtil::GenerateUUID();
+  node_info_.node_role_ = NodeRole::WORKER;
+  MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+               << ", the node id is:" << node_info_.node_id_;
+  InitClientToScheduler();
+}
+
+void WorkerNode::InitClientToScheduler() {
+  std::string scheduler_host = ClusterConfig::scheduler_host();
+  uint16_t scheduler_port = ClusterConfig::scheduler_port();
+  client_to_scheduler_ = std::make_shared(scheduler_host, scheduler_port);
+  client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
+    switch (message.pb_meta().cmd()) {
+      case NodeCommand::HEARTBEAT:
+        ProcessHeartbeatResp(message);
+        break;
+      case NodeCommand::REGISTER:
+        ProcessRegisterResp(message);
+        break;
+      case NodeCommand::FETCH_SERVER:
+        ProcessFetchServersResp(message);
+        break;
+      case NodeCommand::FINISH:
+        MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
+        break;
+      default:
+        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+    }
+    NotifyMessageArrival(message);
+  });
+
+  client_to_scheduler_->Init();
+  worker_thread_ = std::make_unique([&]() {
+    MS_LOG(INFO) << "The worker node start a tcp client!";
+    client_to_scheduler_->Start();
+  });
+  worker_thread_->detach();
+}
+
+bool WorkerNode::Stop() {
+  MS_LOG(INFO) << "Stop worker node!";
+  if (!is_already_stopped_.load()) {
+    is_ready_ = true;
+    is_timeout_ = true;
+    client_to_scheduler_->Stop();
+    if (!connected_nodes_.empty()) {
+      for (auto &connected_node : connected_nodes_) {
+        connected_node.second->Stop();
+      }
+    }
+    client_to_scheduler_->StopEventBase();
+    if (worker_thread_->joinable()) {
+      worker_thread_->join();
+    }
+    if (heart_beat_thread_->joinable()) {
+      heart_beat_thread_->join();
+    }
+    is_already_stopped_ = true;
+  }
+  return true;
+}
+
+bool WorkerNode::Finish(const uint32_t &timeout) {
+  std::lock_guard lock(finish_mutex_);
+  if (is_already_finished_) {
+    MS_LOG(INFO) << "Worker node already finish!";
+    return true;
+  }
+  MS_LOG(INFO) << "Finish worker node!";
+  is_already_finished_ = true;
+  return Disconnect(client_to_scheduler_, timeout);
+}
+
+bool WorkerNode::BroadcastToServers(const std::string &message) {
+  uint64_t request_id = ++next_request_id_;
+  message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
+  for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
+    MessageMeta message_meta;
+    message_meta.set_cmd(NodeCommand::SEND_DATA);
+    message_meta.set_request_id(request_id);
+
+    CommMessage comm_message;
+    *comm_message.mutable_pb_meta() = {message_meta};
+    comm_message.set_data(message);
+    auto client = GetOrCreateTcpClient((*it).first.second);
+    client->SendMessage(comm_message);
+  }
+  return Wait(request_id);
+}
+} // namespace core
+} // namespace ps
+} // namespace mindspore
diff --git a/mindspore/ccsrc/ps/core/worker_node.h b/mindspore/ccsrc/ps/core/worker_node.h
new file mode 100644
index 0000000000..32f6622fa5
--- /dev/null
+++ b/mindspore/ccsrc/ps/core/worker_node.h
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
+#define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "proto/comm.pb.h"
+#include "proto/ps.pb.h"
+#include "ps/core/cluster_config.h"
+#include "ps/core/tcp_client.h"
+#include "ps/core/tcp_server.h"
+#include "ps/core/node.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace ps {
+namespace core {
+class WorkerNode : public Node {
+ public:
+  WorkerNode() : client_to_scheduler_(nullptr), worker_thread_(nullptr) {}
+  ~WorkerNode() override;
+
+  bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
+  bool Stop() override;
+  bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
+
+  bool BroadcastToServers(const std::string &message);
+
+ private:
+  void Register();
+  void ProcessRegisterResp(const CommMessage &message);
+
+  void Initialize();
+  void InitClientToScheduler();
+
+  std::shared_ptr client_to_scheduler_;
+  std::unique_ptr worker_thread_;
+};
+} // namespace core
+} // namespace ps
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
diff --git a/tests/ut/cpp/ps/core/common_util_test.cc b/tests/ut/cpp/ps/core/common_util_test.cc
index 4b58469248..0f42acb81b 100644
--- a/tests/ut/cpp/ps/core/common_util_test.cc
+++ b/tests/ut/cpp/ps/core/common_util_test.cc
@@ -39,6 +39,14 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
   EXPECT_TRUE(!interface.empty());
   EXPECT_TRUE(!ip.empty());
 }
+
+TEST_F(TestCommUtil, ValidateRankId) {
+ClusterConfig::Init(3, 2, std::make_unique("127.0.0.1"), 9999);
+EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2));
+EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3));
+EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));
+EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2));
+}
 } // namespace comm
 } // namespace ps
 } // namespace mindspore
\ No newline at end of file
diff --git a/tests/ut/cpp/ps/core/tcp_message_handler_test.cc b/tests/ut/cpp/ps/core/tcp_message_handler_test.cc
index 65bc90ae73..f1382ad5a3 100644
--- a/tests/ut/cpp/ps/core/tcp_message_handler_test.cc
+++ b/tests/ut/cpp/ps/core/tcp_message_handler_test.cc
@@ -33,117 +33,118 @@ class TestTcpMessageHandler : public UT::Common {
   void TearDown() override {}
 };

-TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) {
+TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
   TcpMessageHandler handler;
   handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });

   std::string data(1000, 'a');
   CommMessage message;
   message.set_data(data);
-  uint32_t buf_size = message.ByteSizeLong();
-  char result[1007];
-  int ret = memcpy_s(result, 4, &buf_size, 4);
+  size_t buf_size = message.ByteSizeLong();
+  char result[1011];
+  int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
   std::vector serialized(buf_size);
   message.SerializeToArray(serialized.data(), static_cast(buf_size));
-  memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
-  handler.ReceiveMessage(result, buf_size + 4);
+  memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
+  handler.ReceiveMessage(result, buf_size + kHeaderLen);
 }

-TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) {
+TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
   TcpMessageHandler handler;
   handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });

   std::string data(1000, 'a');
   CommMessage message;
   message.set_data(data);
-  uint32_t buf_size = message.ByteSizeLong();
-  char result[2014];
-  int ret = memcpy_s(result, 4, &buf_size, 4);
+  size_t buf_size = message.ByteSizeLong();
+  char result[2022] = {0};
+  int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
   std::vector serialized(buf_size);
   message.SerializeToArray(serialized.data(), static_cast(buf_size));
-  ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
+  ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
-  ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
+  ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
-  ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size);
+  ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }

-  handler.ReceiveMessage(result, 2 * buf_size + 4 * 2);
+  handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2);
 }

-TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) {
+TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
   TcpMessageHandler handler;
-  handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); });
+  handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); });

-  std::string data(4087, 'a');
+  std::string data(4081, 'a');
   CommMessage message;
   message.set_data(data);
-  uint32_t buf_size = message.ByteSizeLong();
-  char result[4096];
-  int ret = memcpy_s(result, 4, &buf_size, 4);
+  size_t buf_size = message.ByteSizeLong();
+  char result[4096] = {0};
+  int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
   std::vector serialized(buf_size);
   message.SerializeToArray(serialized.data(), static_cast(buf_size));
-  ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
+  ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
-  ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2);
+  ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
   handler.ReceiveMessage(result, 4096);

-  ret = memcpy_s(result, 2, &buf_size + 2, 2);
+  auto temp = reinterpret_cast(&buf_size);
+  ret = memcpy_s(result, 4, temp + 4, 4);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
-  ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size);
+  ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
-  handler.ReceiveMessage(result, 4092);
+  handler.ReceiveMessage(result, 4088);
 }

-TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) {
+TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) {
   TcpMessageHandler handler;
-  handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); });
+  handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); });

-  std::string data(4085, 'a');
+  std::string data(4077, 'a');
   CommMessage message;
   message.set_data(data);
-  uint32_t buf_size = message.ByteSizeLong();
-  char result[4096];
-  int ret = memcpy_s(result, 4, &buf_size, 4);
+  size_t buf_size = message.ByteSizeLong();
+  char result[4096] = {0};
+  int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
   std::vector serialized(buf_size);
   message.SerializeToArray(serialized.data(), static_cast(buf_size));
-  ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
+  ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
-  ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
+  ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }
@@ -155,9 +156,8 @@ TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) {
     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
   }

-  handler.ReceiveMessage(result, 4088);
+  handler.ReceiveMessage(result, 4080);
 }
-
-} // namespace comm
+} // namespace core
 } // namespace ps
 } // namespace mindspore
\ No newline at end of file