From 24345595f3565e7c9d7b008b0df5fec4b05cb8c3 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Wed, 23 Dec 2020 16:40:58 +0800 Subject: [PATCH] added collective send and receive --- mindspore/ccsrc/ps/core/abstract_node.cc | 152 +++++++++++++++++++-- mindspore/ccsrc/ps/core/abstract_node.h | 56 +++++--- mindspore/ccsrc/ps/core/protos/comm.proto | 1 + mindspore/ccsrc/ps/core/scheduler_node.cc | 23 +--- mindspore/ccsrc/ps/core/scheduler_node.h | 1 + mindspore/ccsrc/ps/core/server_node.cc | 50 ++++--- mindspore/ccsrc/ps/core/server_node.h | 5 +- mindspore/ccsrc/ps/core/tcp_client.cc | 51 ++++--- mindspore/ccsrc/ps/core/tcp_client.h | 2 +- mindspore/ccsrc/ps/core/tcp_server.cc | 54 ++++---- mindspore/ccsrc/ps/core/tcp_server.h | 2 +- mindspore/ccsrc/ps/core/worker_node.cc | 29 +--- tests/ut/cpp/ps/core/abstract_node_test.cc | 42 ++++++ 13 files changed, 316 insertions(+), 152 deletions(-) create mode 100644 tests/ut/cpp/ps/core/abstract_node_test.cc diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index d7d973efc5..52dd4f48e0 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -53,13 +53,20 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) { MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; } -bool AbstractNode::BroadcastToServers(const std::string &message, const uint32_t &timeout) { +bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) { + if (node_role != NodeRole::SERVER) { + MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes"; + } + uint64_t request_id = ++next_request_id_; message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); + for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { MessageMeta message_meta; message_meta.set_cmd(NodeCommand::SEND_DATA); message_meta.set_request_id(request_id); + message_meta.set_rank_id(node_info_.rank_id_); + message_meta.set_role(node_info_.node_role_); CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; @@ -82,12 +89,14 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, MessageMeta message_meta; message_meta.set_cmd(NodeCommand::SEND_DATA); + message_meta.set_rank_id(node_info_.rank_id_); + message_meta.set_role(node_info_.node_role_); CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; comm_message.set_data(message); auto client = GetOrCreateTcpClient(rank_id); - return SendMessageSync(client, comm_message); + return SendMessageSync(client, comm_message, timeout); } bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, @@ -106,6 +115,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & MessageMeta message_meta; message_meta.set_cmd(NodeCommand::SEND_DATA); message_meta.set_request_id(request_id); + message_meta.set_rank_id(node_info_.rank_id_); + message_meta.set_role(node_info_.node_role_); CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; @@ -118,8 +129,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & } bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - CommMessage *comm_message_resp, const uint32_t &timeout) { - MS_EXCEPTION_IF_NULL(comm_message_resp); + CommMessage *output, const uint32_t &timeout) { + MS_EXCEPTION_IF_NULL(output); if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } @@ -129,7 +140,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); auto res = receive_messages_[request_id]; - *comm_message_resp = res[rank_id]; + *output = res[rank_id]; receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); }); @@ -149,9 +160,9 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, } bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, std::vector *comm_message_resp, + const std::vector &data, std::vector *output, const uint32_t &timeout) { - MS_EXCEPTION_IF_NULL(comm_message_resp); + MS_EXCEPTION_IF_NULL(output); uint64_t request_id = ++next_request_id_; message_tracker_[request_id] = std::make_pair(data.size(), 0); @@ -165,7 +176,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & receive_messages_mutex_.lock(); auto res = receive_messages_[request_id]; for (size_t it = 0; it < len; ++it) { - (*comm_message_resp).push_back(res[rank_ids.at(it)]); + (*output).push_back(res[rank_ids.at(it)]); } receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); @@ -179,6 +190,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & MessageMeta message_meta; message_meta.set_cmd(NodeCommand::SEND_DATA); message_meta.set_request_id(request_id); + message_meta.set_rank_id(node_info_.rank_id_); + message_meta.set_role(node_info_.node_role_); CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; @@ -200,6 +213,58 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { return res; } +uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, + const std::string &message, const uint32_t &timeout) { + if (!CommUtil::ValidateRankId(node_role, rank_id)) { + MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; + } + + MessageMeta message_meta; + message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA); + message_meta.set_rank_id(node_info_.rank_id_); + message_meta.set_role(node_info_.node_role_); + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message_meta}; + comm_message.set_data(message); + auto client = GetOrCreateTcpClient(rank_id); + return SendMessageAsync(client, comm_message); +} + +std::pair AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, + const uint32_t &rank_id, CommMessage *output) { + if (!CommUtil::ValidateRankId(node_role, rank_id)) { + MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; + } + + uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); + if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { + *output = received_data_[std::make_pair(rank_id, rank_request_id)]; + received_data_.erase(std::make_pair(rank_id, rank_request_id)); + } else { + set_receive_callback(rank_id, rank_request_id, [=]() { + receive_callbacks_mutex_.lock(); + *output = received_data_[std::make_pair(rank_id, 1)]; + received_data_.erase(std::make_pair(rank_id, rank_request_id)); + receive_callbacks_mutex_.unlock(); + }); + } + return std::make_pair(rank_id, rank_request_id); +} + +bool AbstractNode::CollectiveWait(std::pair request_id, const uint32_t &timeout) { + std::unique_lock lock(receive_callbacks_mutex_); + bool res = receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { + if (actual_rank_request_ids_.count(request_id.first) && + (actual_rank_request_ids_[request_id.first] >= request_id.second)) { + return true; + } else { + return false; + } + }); + return res; +} + void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) { MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ @@ -210,7 +275,6 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); } }); - heart_beat_thread_->detach(); } void AbstractNode::Heartbeat(const std::shared_ptr &client, bool is_node_finish) { @@ -334,11 +398,9 @@ bool AbstractNode::InitClientToScheduler() { MS_LOG(INFO) << "The worker node start a tcp client!"; client_to_scheduler_->Start(); }); - client_to_scheduler_thread_->detach(); client_to_scheduler_->set_disconnected_callback([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval())); - client_to_scheduler_->Stop(); client_to_scheduler_->Init(); }); return client_to_scheduler_->WaitConnected(); @@ -361,6 +423,9 @@ const std::shared_ptr &AbstractNode::GetOrCreateTcpClient(const int & ProcessSendDataResp(message); RunMessageCallback(message.pb_meta().request_id()); break; + case NodeCommand::COLLECTIVE_SEND_DATA: + MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!"; + break; default: MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; } @@ -381,10 +446,12 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr &client, con return Wait(request_id, timeout); } -void AbstractNode::SendMessageAsync(const std::shared_ptr &client, const CommMessage &message) { +uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr &client, const CommMessage &message) { uint64_t request_id = ++next_request_id_; + message_tracker_[request_id] = std::make_pair(1, 0); const_cast(message).mutable_pb_meta()->set_request_id(request_id); client->SendMessage(message); + return request_id; } void AbstractNode::ProcessSendDataResp(const CommMessage &message) { @@ -422,12 +489,12 @@ void AbstractNode::RunMessageCallback(const uint64_t &request_id) { message_callbacks_mutex_.unlock(); } -void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) { - if (!message_callback) { +void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &callback) { + if (!callback) { return; } std::lock_guard lock(message_callbacks_mutex_); - message_callbacks_[request_id] = message_callback; + message_callbacks_[request_id] = callback; } void AbstractNode::NotifyMessageArrival(const CommMessage &message) { @@ -438,6 +505,61 @@ void AbstractNode::NotifyMessageArrival(const CommMessage &message) { message_tracker_[request_id].second++; message_tracker_cond_.notify_all(); } + +void AbstractNode::set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, + const MessageCallback &callback) { + if (!callback) { + return; + } + std::lock_guard lock(receive_callbacks_mutex_); + receive_callbacks_[std::make_pair(rank_id, request_id)] = callback; +} + +void AbstractNode::RunReceiveCallback(const CommMessage &message) { + receive_callbacks_mutex_.lock(); + uint32_t rank_id = message.pb_meta().rank_id(); + // When receiving a collective message, Then generate rank request id,compare with the desired rank request id, + // If they are equal, then call the callback function + uint64_t rank_request_id = NextActualRankRequestId(rank_id); + received_data_[std::make_pair(rank_id, rank_request_id)] = message; + auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id)); + if (it != receive_callbacks_.end()) { + receive_callbacks_mutex_.unlock(); + + if (it->second) { + it->second(); + } + + receive_callbacks_mutex_.lock(); + receive_cond_.notify_all(); + receive_callbacks_.erase(it); + } + receive_callbacks_mutex_.unlock(); +} + +uint64_t AbstractNode::NextExpectedRankRequestId(const uint32_t &rank_id) { + std::lock_guard lock(rank_request_ids_mutex); + uint64_t rank_request_id = 1; + if (expected_rank_request_ids_.count(rank_id)) { + rank_request_id = ++expected_rank_request_ids_[rank_id]; + expected_rank_request_ids_[rank_id] = rank_request_id; + } else { + expected_rank_request_ids_[rank_id] = rank_request_id; + } + return rank_request_id; +} + +uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) { + std::lock_guard lock(rank_request_ids_mutex); + uint64_t rank_request_id = 1; + if (actual_rank_request_ids_.count(rank_id)) { + rank_request_id = ++actual_rank_request_ids_[rank_id]; + actual_rank_request_ids_[rank_id] = rank_request_id; + } else { + actual_rank_request_ids_[rank_id] = rank_request_id; + } + return rank_request_id; +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h index 6b2fe52153..dff77346e1 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.h +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -34,21 +34,26 @@ class AbstractNode : public Node { AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} ~AbstractNode() override = default; - bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds); + bool Broadcast(const enum NodeRole &node_role, const std::string &message, + const uint32_t &timeout = kCommTimeoutInSeconds); void set_event_callback(const OnNodeEventMessage &on_node_event_message); - virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - const uint32_t &timeout = kCommTimeoutInSeconds); - virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, const uint32_t &timeout = kCommTimeoutInSeconds); - virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); - virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, std::vector *comm_message_resp, - const uint32_t &timeout = kCommTimeoutInSeconds); - + bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, + const uint32_t &timeout = kCommTimeoutInSeconds); + bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, + const uint32_t &timeout = kCommTimeoutInSeconds); + bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output, + const uint32_t &timeout = kCommTimeoutInSeconds); + bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, + std::vector *output, const uint32_t &timeout = kCommTimeoutInSeconds); bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); + uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, + const uint32_t &timeout = kCommTimeoutInSeconds); + std::pair CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, + CommMessage *output); + bool CollectiveWait(std::pair request_id, const uint32_t &timeout = kCommTimeoutInSeconds); + protected: void Register(const std::shared_ptr &client); void ProcessRegisterResp(const CommMessage &message); @@ -63,34 +68,51 @@ class AbstractNode : public Node { const std::shared_ptr &GetOrCreateTcpClient(const int &rank_id); bool SendMessageSync(const std::shared_ptr &client, const CommMessage &message, const uint32_t &timeout = kCommTimeoutInSeconds); - void SendMessageAsync(const std::shared_ptr &client, const CommMessage &message); + uint64_t SendMessageAsync(const std::shared_ptr &client, const CommMessage &message); void ProcessSendDataResp(const CommMessage &message); void RunMessageCallback(const uint64_t &request_id); - void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); + void set_message_callback(const uint64_t &request_id, const MessageCallback &callback); void NotifyMessageArrival(const CommMessage &message); + void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback); + void RunReceiveCallback(const CommMessage &message); + uint64_t NextExpectedRankRequestId(const uint32_t &rank_id); + uint64_t NextActualRankRequestId(const uint32_t &rank_id); std::unique_ptr heart_beat_thread_; std::unique_ptr client_to_scheduler_thread_; std::shared_ptr client_to_scheduler_; OnNodeEventMessage on_node_event_message_; - // the map's key is: , the map's value is: + // the key is: , the value is: std::map, std::pair> nodes_address_; std::mutex client_mutex_; // the map's key is: rank_id std::unordered_map> connected_nodes_; - // the map's key is: request_id, the map's value is: + // the key is: request_id, the value is: std::unordered_map> message_tracker_; std::mutex message_tracker_mutex_; std::condition_variable message_tracker_cond_; - // the map's key is: request_id, the map's value is: + // the key is: request_id, the value is: std::unordered_map> receive_messages_; std::mutex receive_messages_mutex_; - // the map's key is: request_id + // the key is: request_id std::unordered_map message_callbacks_; std::mutex message_callbacks_mutex_; + + // the key is + std::map, CommMessage> received_data_; + std::mutex receive_callbacks_mutex_; + // the key is + std::map, MessageCallback> receive_callbacks_; + std::condition_variable receive_cond_; + + // the key is rank_id, the value is rank_id's expected request_id + std::unordered_map expected_rank_request_ids_; + // the key is rank_id, the value is rank_id's actual request_id + std::unordered_map actual_rank_request_ids_; + std::mutex rank_request_ids_mutex; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto index 0d4aa67c59..4e24de8c58 100644 --- a/mindspore/ccsrc/ps/core/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -26,6 +26,7 @@ enum NodeCommand { SEND_DATA = 3; FETCH_SERVER = 4; FINISH = 5; + COLLECTIVE_SEND_DATA = 6; } enum NodeRole { diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index 270e540378..fb593b0da1 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -19,19 +19,10 @@ namespace mindspore { namespace ps { namespace core { + SchedulerNode::~SchedulerNode() { MS_LOG(INFO) << "Stop scheduler node!"; - if (!is_already_stopped_) { - is_already_stopped_ = true; - server_->Stop(); - if (scheduler_thread_->joinable()) { - scheduler_thread_->join(); - } - if (update_state_thread_->joinable()) { - update_state_thread_->join(); - } - is_ready_ = true; - } + Stop(); } bool SchedulerNode::Start(const uint32_t &timeout) { @@ -114,7 +105,6 @@ void SchedulerNode::CreateTcpServer() { MS_LOG(INFO) << "The scheduler node start a tcp server!"; server_->Start(); }); - scheduler_thread_->detach(); } void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { @@ -186,20 +176,15 @@ void SchedulerNode::StartUpdateClusterStateTimer() { } } }); - update_state_thread_->detach(); } bool SchedulerNode::Stop() { MS_LOG(INFO) << "Stop scheduler node!"; if (!is_already_stopped_) { is_already_stopped_ = true; + update_state_thread_->join(); server_->Stop(); - if (scheduler_thread_->joinable()) { - scheduler_thread_->join(); - } - if (update_state_thread_->joinable()) { - update_state_thread_->join(); - } + scheduler_thread_->join(); is_ready_ = true; } return true; diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index f7fe022a80..86488ea9ac 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -38,6 +38,7 @@ namespace mindspore { namespace ps { namespace core { + class SchedulerNode : public Node { public: SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 2ac8861b24..8a76a5b7d4 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -20,18 +20,7 @@ namespace ps { namespace core { ServerNode::~ServerNode() { MS_LOG(INFO) << "Stop server node!"; - if (!is_already_stopped_.load()) { - server_->Stop(); - client_to_scheduler_->Stop(); - client_to_scheduler_->StopEventBase(); - if (server_thread_->joinable()) { - server_thread_->join(); - } - if (client_to_scheduler_thread_->joinable()) { - client_to_scheduler_thread_->join(); - } - is_already_stopped_ = true; - } + Stop(); } bool ServerNode::Start(const uint32_t &timeout) { @@ -78,6 +67,10 @@ void ServerNode::CreateTcpServer() { case NodeCommand::SEND_DATA: ProcessSendData(server, conn, message); break; + case NodeCommand::COLLECTIVE_SEND_DATA: + ProcessCollectiveSendData(server, conn, message); + RunReceiveCallback(message); + break; default: MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; } @@ -87,7 +80,6 @@ void ServerNode::CreateTcpServer() { MS_LOG(INFO) << "The server node start a tcp server!"; server_->Start(); }); - server_thread_->detach(); } void ServerNode::Initialize() { @@ -106,27 +98,31 @@ void ServerNode::Initialize() { } void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { - if (request_handler_) { - request_handler_(server, conn, message.pb_meta(), message.data()); - } + request_handler_(server, conn, message.pb_meta(), message.data()); +} + +void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, + const CommMessage &message) { + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message.pb_meta()}; + const_cast(server).SendMessage(conn, comm_message); } bool ServerNode::Stop() { MS_LOG(INFO) << "Stop server node!"; if (!is_already_stopped_.load()) { - server_->Stop(); + is_already_stopped_ = true; + is_finish_ = true; + heart_beat_thread_->join(); client_to_scheduler_->Stop(); - client_to_scheduler_->StopEventBase(); - if (server_thread_->joinable()) { - server_thread_->join(); + if (!connected_nodes_.empty()) { + for (auto &connected_node : connected_nodes_) { + connected_node.second->Stop(); + } } - if (client_to_scheduler_thread_->joinable()) { - client_to_scheduler_thread_->join(); - } - if (heart_beat_thread_->joinable()) { - heart_beat_thread_->join(); - } - is_already_stopped_ = true; + client_to_scheduler_thread_->join(); + server_->Stop(); + server_thread_->join(); } return true; } diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h index 2c3d728dfa..77c196902f 100644 --- a/mindspore/ccsrc/ps/core/server_node.h +++ b/mindspore/ccsrc/ps/core/server_node.h @@ -44,8 +44,8 @@ class ServerNode : public AbstractNode { bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; - using RequestHandler = std::function; + using RequestHandler = std::function; void set_handler(const RequestHandler &handler); void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, @@ -55,6 +55,7 @@ class ServerNode : public AbstractNode { void CreateTcpServer(); void Initialize(); void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); + void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); std::shared_ptr server_; std::unique_ptr server_thread_; diff --git a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc index 5393daa11d..d607b819a0 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -51,7 +51,20 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port) }); } -TcpClient::~TcpClient() { Stop(); } +TcpClient::~TcpClient() { + if (buffer_event_) { + bufferevent_free(buffer_event_); + buffer_event_ = nullptr; + } + if (event_timeout_) { + event_free(event_timeout_); + event_timeout_ = nullptr; + } + if (event_base_) { + event_base_free(event_base_); + event_base_ = nullptr; + } +} std::string TcpClient::GetServerAddress() const { return server_address_; } @@ -69,9 +82,9 @@ bool TcpClient::WaitConnected(const uint32_t &connected_timeout) { void TcpClient::Init() { std::lock_guard lock(connection_mutex_); if (buffer_event_) { - return; + bufferevent_free(buffer_event_); + buffer_event_ = nullptr; } - is_stop_ = false; if (!CommUtil::CheckIp(server_address_)) { MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; } @@ -82,8 +95,9 @@ void TcpClient::Init() { } if (event_base_ == nullptr) { event_base_ = event_base_new(); + MS_EXCEPTION_IF_NULL(event_base_); + is_stop_ = false; } - MS_EXCEPTION_IF_NULL(event_base_); sockaddr_in sin{}; if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { @@ -127,26 +141,18 @@ void TcpClient::StartWithDelay(int seconds) { void TcpClient::Stop() { std::lock_guard lock(connection_mutex_); - MS_LOG(INFO) << "Stop tcp client event buffer!"; - if (!is_stop_.load()) { - if (buffer_event_) { - bufferevent_free(buffer_event_); - buffer_event_ = nullptr; - } - - if (event_timeout_) { - event_free(event_timeout_); - event_timeout_ = nullptr; - } + MS_LOG(INFO) << "Stop tcp client!"; + if (event_base_got_break(event_base_)) { + MS_LOG(DEBUG) << "The event base has stopped!"; is_stop_ = true; + return; } -} - -void TcpClient::StopEventBase() { - MS_LOG(INFO) << "Stop tcp client event base!"; - int ret = event_base_loopbreak(event_base_); - if (ret != 0) { - MS_LOG(ERROR) << "Event base loop break failed!"; + if (!is_stop_.load()) { + is_stop_ = true; + int ret = event_base_loopbreak(event_base_); + if (ret != 0) { + MS_LOG(ERROR) << "Event base loop break failed!"; + } } } @@ -280,6 +286,7 @@ void TcpClient::StartTimer(const uint32_t &time) { void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } const event_base &TcpClient::eventbase() { return *event_base_; } + } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/tcp_client.h b/mindspore/ccsrc/ps/core/tcp_client.h index 0aa8193cfe..ce682e0d57 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.h +++ b/mindspore/ccsrc/ps/core/tcp_client.h @@ -58,7 +58,6 @@ class TcpClient { void Init(); void StartWithDelay(int seconds); void Stop(); - static void StopEventBase(); void Start(); void StartWithNoBlock(); void SetMessageCallback(const OnMessage &cb); @@ -97,6 +96,7 @@ class TcpClient { std::atomic is_stop_; std::atomic is_connected_; }; + } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc index 798d2070dd..afafad7354 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -32,6 +32,7 @@ namespace mindspore { namespace ps { namespace core { + void TcpConnection::InitConnection() { tcp_message_handler_.SetCallback([&](const CommMessage &message) { OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); @@ -76,7 +77,22 @@ TcpServer::TcpServer(const std::string &address, std::uint16_t port) server_port_(port), is_stop_(true) {} -TcpServer::~TcpServer() { Stop(); } +TcpServer::~TcpServer() { + if (signal_event_ != nullptr) { + event_free(signal_event_); + signal_event_ = nullptr; + } + + if (listener_ != nullptr) { + evconnlistener_free(listener_); + listener_ = nullptr; + } + + if (base_ != nullptr) { + event_base_free(base_); + base_ = nullptr; + } +} void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, const OnAccepted &client_accept) { @@ -136,7 +152,6 @@ void TcpServer::Init() { } void TcpServer::Start() { - std::unique_lock lock(connection_mutex_); MS_LOG(INFO) << "Start tcp server!"; MS_EXCEPTION_IF_NULL(base_); int ret = event_base_dispatch(base_); @@ -148,7 +163,7 @@ void TcpServer::Start() { } void TcpServer::StartWithNoBlock() { - std::unique_lock lock(connection_mutex_); + std::lock_guard lock(connection_mutex_); MS_LOG(INFO) << "Start tcp server with no block!"; MS_EXCEPTION_IF_NULL(base_); int ret = event_base_loop(base_, EVLOOP_NONBLOCK); @@ -187,33 +202,25 @@ void TcpServer::StartTimer(const uint32_t &time) { } void TcpServer::Stop() { + std::lock_guard lock(connection_mutex_); MS_LOG(INFO) << "Stop tcp server!"; + if (event_base_got_break(base_)) { + MS_LOG(DEBUG) << "The event base has stopped!"; + is_stop_ = true; + return; + } if (!is_stop_.load()) { + is_stop_ = true; int ret = event_base_loopbreak(base_); if (ret != 0) { - MS_LOG(EXCEPTION) << "event base loop break failed!"; - } - if (signal_event_ != nullptr) { - event_free(signal_event_); - signal_event_ = nullptr; - } - - if (listener_ != nullptr) { - evconnlistener_free(listener_); - listener_ = nullptr; + MS_LOG(ERROR) << "Event base loop break failed!"; } - - if (base_ != nullptr) { - event_base_free(base_); - base_ = nullptr; - } - is_stop_ = true; } } void TcpServer::SendToAllClients(const char *data, size_t len) { MS_EXCEPTION_IF_NULL(data); - std::unique_lock lock(connection_mutex_); + std::lock_guard lock(connection_mutex_); for (auto it = connections_.begin(); it != connections_.end(); ++it) { it->second->SendMessage(data, len); } @@ -221,12 +228,12 @@ void TcpServer::SendToAllClients(const char *data, size_t len) { void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) { MS_EXCEPTION_IF_NULL(connection); - std::unique_lock lock(connection_mutex_); + std::lock_guard lock(connection_mutex_); connections_.insert(std::make_pair(fd, connection)); } void TcpServer::RemoveConnection(const evutil_socket_t &fd) { - std::unique_lock lock(connection_mutex_); + std::lock_guard lock(connection_mutex_); TcpConnection *connection = const_cast(connections_.find(fd)->second); delete connection; connections_.erase(fd); @@ -352,7 +359,7 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) { void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } void TcpServer::SendMessage(const CommMessage &message) { - std::unique_lock lock(connection_mutex_); + std::lock_guard lock(connection_mutex_); for (auto it = connections_.begin(); it != connections_.end(); ++it) { SendMessage(*it->second, message); @@ -368,6 +375,7 @@ int TcpServer::ConnectionNum() const { return connections_.size(); } const std::map &TcpServer::Connections() const { return connections_; } void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } + } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/tcp_server.h b/mindspore/ccsrc/ps/core/tcp_server.h index 43294aa5e9..c268775bfd 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.h +++ b/mindspore/ccsrc/ps/core/tcp_server.h @@ -121,7 +121,7 @@ class TcpServer { OnConnected client_connection_; OnDisconnected client_disconnection_; OnAccepted client_accept_; - std::recursive_mutex connection_mutex_; + std::mutex connection_mutex_; OnServerReceiveMessage message_callback_; OnTimerOnce on_timer_once_callback_; OnTimer on_timer_callback_; diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index 3a2f40f92e..ee162e070b 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -21,24 +21,7 @@ namespace ps { namespace core { WorkerNode::~WorkerNode() { MS_LOG(INFO) << "Stop worker node!"; - if (!is_already_stopped_.load()) { - is_ready_ = true; - is_timeout_ = true; - client_to_scheduler_->Stop(); - if (!connected_nodes_.empty()) { - for (auto &connected_node : connected_nodes_) { - connected_node.second->Stop(); - } - } - client_to_scheduler_->StopEventBase(); - if (client_to_scheduler_thread_->joinable()) { - client_to_scheduler_thread_->join(); - } - if (heart_beat_thread_->joinable()) { - heart_beat_thread_->join(); - } - is_already_stopped_ = true; - } + Stop(); } bool WorkerNode::Start(const uint32_t &timeout) { MS_LOG(INFO) << "Starting worker node!"; @@ -78,19 +61,15 @@ bool WorkerNode::Stop() { if (!is_already_stopped_.load()) { is_ready_ = true; is_timeout_ = true; + is_finish_ = true; + heart_beat_thread_->join(); client_to_scheduler_->Stop(); if (!connected_nodes_.empty()) { for (auto &connected_node : connected_nodes_) { connected_node.second->Stop(); } } - client_to_scheduler_->StopEventBase(); - if (client_to_scheduler_thread_->joinable()) { - client_to_scheduler_thread_->join(); - } - if (heart_beat_thread_->joinable()) { - heart_beat_thread_->join(); - } + client_to_scheduler_thread_->join(); is_already_stopped_ = true; } return true; diff --git a/tests/ut/cpp/ps/core/abstract_node_test.cc b/tests/ut/cpp/ps/core/abstract_node_test.cc new file mode 100644 index 0000000000..8921d43b4e --- /dev/null +++ b/tests/ut/cpp/ps/core/abstract_node_test.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common_test.h" +#define protected public +#include "ps/core/worker_node.h" +#undef protected + +namespace mindspore { +namespace ps { +namespace core { +class TestAbstractNode : public UT::Common { + public: + TestAbstractNode() = default; + virtual ~TestAbstractNode() = default; + + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(TestAbstractNode, NextExpectedRankRequestId) { + WorkerNode workerNode; + ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(0)); + ASSERT_EQ(2, workerNode.NextExpectedRankRequestId(0)); + ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(1)); +} +} // namespace core +} // namespace ps +} // namespace mindspore \ No newline at end of file