!10393 added collective send and receive

From: @anancds
Reviewed-by: 
Signed-off-by:
pull/10393/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit b96d4315dc

File diff suppressed because it is too large Load Diff

@ -34,21 +34,26 @@ class AbstractNode : public Node {
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
~AbstractNode() override = default;
bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Broadcast(const enum NodeRole &node_role, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage> *comm_message_resp,
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
CommMessage *output);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
protected:
void Register(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(const CommMessage &message);
@ -63,34 +68,51 @@ class AbstractNode : public Node {
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
void ProcessSendDataResp(const CommMessage &message);
void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
void NotifyMessageArrival(const CommMessage &message);
void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback);
void RunReceiveCallback(const CommMessage &message);
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_;
std::shared_ptr<TcpClient> client_to_scheduler_;
OnNodeEventMessage on_node_event_message_;
// the map's key is: <node_role,rank_id>, the map's value is: <ip, port>
// the key is: <node_role,rank_id>, the value is: <ip, port>
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
std::mutex client_mutex_;
// the map's key is: rank_id
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
// the map's key is: request_id, the map's value is: <expected responses, actual responses>
// the key is: request_id, the value is: <expected responses, actual responses>
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;
// the map's key is: request_id, the map's value is:<rank_id, CommMessage>
// the key is: request_id, the value is:<rank_id, CommMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
std::mutex receive_messages_mutex_;
// the map's key is: request_id
// the key is: request_id
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
std::mutex message_callbacks_mutex_;
// the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, CommMessage> received_data_;
std::mutex receive_callbacks_mutex_;
// the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_;
std::condition_variable receive_cond_;
// the key is rank_id, the value is rank_id's expected request_id
std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_;
// the key is rank_id, the value is rank_id's actual request_id
std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
std::mutex rank_request_ids_mutex;
};
} // namespace core
} // namespace ps

@ -26,6 +26,7 @@ enum NodeCommand {
SEND_DATA = 3;
FETCH_SERVER = 4;
FINISH = 5;
COLLECTIVE_SEND_DATA = 6;
}
enum NodeRole {

@ -19,19 +19,10 @@
namespace mindspore {
namespace ps {
namespace core {
SchedulerNode::~SchedulerNode() {
MS_LOG(INFO) << "Stop scheduler node!";
if (!is_already_stopped_) {
is_already_stopped_ = true;
server_->Stop();
if (scheduler_thread_->joinable()) {
scheduler_thread_->join();
}
if (update_state_thread_->joinable()) {
update_state_thread_->join();
}
is_ready_ = true;
}
Stop();
}
bool SchedulerNode::Start(const uint32_t &timeout) {
@ -114,7 +105,6 @@ void SchedulerNode::CreateTcpServer() {
MS_LOG(INFO) << "The scheduler node start a tcp server!";
server_->Start();
});
scheduler_thread_->detach();
}
void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
@ -186,20 +176,15 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
}
}
});
update_state_thread_->detach();
}
bool SchedulerNode::Stop() {
MS_LOG(INFO) << "Stop scheduler node!";
if (!is_already_stopped_) {
is_already_stopped_ = true;
update_state_thread_->join();
server_->Stop();
if (scheduler_thread_->joinable()) {
scheduler_thread_->join();
}
if (update_state_thread_->joinable()) {
update_state_thread_->join();
}
is_ready_ = true;
}
return true;

@ -38,6 +38,7 @@
namespace mindspore {
namespace ps {
namespace core {
class SchedulerNode : public Node {
public:
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}

@ -20,18 +20,7 @@ namespace ps {
namespace core {
ServerNode::~ServerNode() {
MS_LOG(INFO) << "Stop server node!";
if (!is_already_stopped_.load()) {
server_->Stop();
client_to_scheduler_->Stop();
client_to_scheduler_->StopEventBase();
if (server_thread_->joinable()) {
server_thread_->join();
}
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
is_already_stopped_ = true;
}
Stop();
}
bool ServerNode::Start(const uint32_t &timeout) {
@ -78,6 +67,10 @@ void ServerNode::CreateTcpServer() {
case NodeCommand::SEND_DATA:
ProcessSendData(server, conn, message);
break;
case NodeCommand::COLLECTIVE_SEND_DATA:
ProcessCollectiveSendData(server, conn, message);
RunReceiveCallback(message);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
@ -87,7 +80,6 @@ void ServerNode::CreateTcpServer() {
MS_LOG(INFO) << "The server node start a tcp server!";
server_->Start();
});
server_thread_->detach();
}
void ServerNode::Initialize() {
@ -106,27 +98,31 @@ void ServerNode::Initialize() {
}
void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
if (request_handler_) {
request_handler_(server, conn, message.pb_meta(), message.data());
}
void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn,
const CommMessage &message) {
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message.pb_meta()};
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
}
bool ServerNode::Stop() {
MS_LOG(INFO) << "Stop server node!";
if (!is_already_stopped_.load()) {
server_->Stop();
is_already_stopped_ = true;
is_finish_ = true;
heart_beat_thread_->join();
client_to_scheduler_->Stop();
client_to_scheduler_->StopEventBase();
if (server_thread_->joinable()) {
server_thread_->join();
}
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
}
is_already_stopped_ = true;
client_to_scheduler_thread_->join();
server_->Stop();
server_thread_->join();
}
return true;
}

@ -44,8 +44,8 @@ class ServerNode : public AbstractNode {
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn,
const MessageMeta message_meta, const std::string &message)>;
using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, const MessageMeta meta,
const std::string &message)>;
void set_handler(const RequestHandler &handler);
void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
@ -55,6 +55,7 @@ class ServerNode : public AbstractNode {
void CreateTcpServer();
void Initialize();
void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;

@ -51,7 +51,20 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
});
}
TcpClient::~TcpClient() { Stop(); }
TcpClient::~TcpClient() {
if (buffer_event_) {
bufferevent_free(buffer_event_);
buffer_event_ = nullptr;
}
if (event_timeout_) {
event_free(event_timeout_);
event_timeout_ = nullptr;
}
if (event_base_) {
event_base_free(event_base_);
event_base_ = nullptr;
}
}
std::string TcpClient::GetServerAddress() const { return server_address_; }
@ -69,9 +82,9 @@ bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
void TcpClient::Init() {
std::lock_guard<std::mutex> lock(connection_mutex_);
if (buffer_event_) {
return;
bufferevent_free(buffer_event_);
buffer_event_ = nullptr;
}
is_stop_ = false;
if (!CommUtil::CheckIp(server_address_)) {
MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
}
@ -82,8 +95,9 @@ void TcpClient::Init() {
}
if (event_base_ == nullptr) {
event_base_ = event_base_new();
}
MS_EXCEPTION_IF_NULL(event_base_);
is_stop_ = false;
}
sockaddr_in sin{};
if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
@ -127,28 +141,20 @@ void TcpClient::StartWithDelay(int seconds) {
void TcpClient::Stop() {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Stop tcp client event buffer!";
if (!is_stop_.load()) {
if (buffer_event_) {
bufferevent_free(buffer_event_);
buffer_event_ = nullptr;
}
if (event_timeout_) {
event_free(event_timeout_);
event_timeout_ = nullptr;
}
MS_LOG(INFO) << "Stop tcp client!";
if (event_base_got_break(event_base_)) {
MS_LOG(DEBUG) << "The event base has stopped!";
is_stop_ = true;
return;
}
}
void TcpClient::StopEventBase() {
MS_LOG(INFO) << "Stop tcp client event base!";
if (!is_stop_.load()) {
is_stop_ = true;
int ret = event_base_loopbreak(event_base_);
if (ret != 0) {
MS_LOG(ERROR) << "Event base loop break failed!";
}
}
}
void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
const int one = 1;
@ -280,6 +286,7 @@ void TcpClient::StartTimer(const uint32_t &time) {
void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
const event_base &TcpClient::eventbase() { return *event_base_; }
} // namespace core
} // namespace ps
} // namespace mindspore

@ -58,7 +58,6 @@ class TcpClient {
void Init();
void StartWithDelay(int seconds);
void Stop();
static void StopEventBase();
void Start();
void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb);
@ -97,6 +96,7 @@ class TcpClient {
std::atomic<bool> is_stop_;
std::atomic<bool> is_connected_;
};
} // namespace core
} // namespace ps
} // namespace mindspore

@ -32,6 +32,7 @@
namespace mindspore {
namespace ps {
namespace core {
void TcpConnection::InitConnection() {
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
@ -76,7 +77,22 @@ TcpServer::TcpServer(const std::string &address, std::uint16_t port)
server_port_(port),
is_stop_(true) {}
TcpServer::~TcpServer() { Stop(); }
TcpServer::~TcpServer() {
if (signal_event_ != nullptr) {
event_free(signal_event_);
signal_event_ = nullptr;
}
if (listener_ != nullptr) {
evconnlistener_free(listener_);
listener_ = nullptr;
}
if (base_ != nullptr) {
event_base_free(base_);
base_ = nullptr;
}
}
void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
const OnAccepted &client_accept) {
@ -136,7 +152,6 @@ void TcpServer::Init() {
}
void TcpServer::Start() {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server!";
MS_EXCEPTION_IF_NULL(base_);
int ret = event_base_dispatch(base_);
@ -148,7 +163,7 @@ void TcpServer::Start() {
}
void TcpServer::StartWithNoBlock() {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server with no block!";
MS_EXCEPTION_IF_NULL(base_);
int ret = event_base_loop(base_, EVLOOP_NONBLOCK);
@ -187,33 +202,25 @@ void TcpServer::StartTimer(const uint32_t &time) {
}
void TcpServer::Stop() {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Stop tcp server!";
if (event_base_got_break(base_)) {
MS_LOG(DEBUG) << "The event base has stopped!";
is_stop_ = true;
return;
}
if (!is_stop_.load()) {
is_stop_ = true;
int ret = event_base_loopbreak(base_);
if (ret != 0) {
MS_LOG(EXCEPTION) << "event base loop break failed!";
}
if (signal_event_ != nullptr) {
event_free(signal_event_);
signal_event_ = nullptr;
}
if (listener_ != nullptr) {
evconnlistener_free(listener_);
listener_ = nullptr;
MS_LOG(ERROR) << "Event base loop break failed!";
}
if (base_ != nullptr) {
event_base_free(base_);
base_ = nullptr;
}
is_stop_ = true;
}
}
void TcpServer::SendToAllClients(const char *data, size_t len) {
MS_EXCEPTION_IF_NULL(data);
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
it->second->SendMessage(data, len);
}
@ -221,12 +228,12 @@ void TcpServer::SendToAllClients(const char *data, size_t len) {
void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) {
MS_EXCEPTION_IF_NULL(connection);
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);
connections_.insert(std::make_pair(fd, connection));
}
void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);
TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
delete connection;
connections_.erase(fd);
@ -352,7 +359,7 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }
void TcpServer::SendMessage(const CommMessage &message) {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(*it->second, message);
@ -368,6 +375,7 @@ int TcpServer::ConnectionNum() const { return connections_.size(); }
const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; }
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
} // namespace core
} // namespace ps
} // namespace mindspore

@ -121,7 +121,7 @@ class TcpServer {
OnConnected client_connection_;
OnDisconnected client_disconnection_;
OnAccepted client_accept_;
std::recursive_mutex connection_mutex_;
std::mutex connection_mutex_;
OnServerReceiveMessage message_callback_;
OnTimerOnce on_timer_once_callback_;
OnTimer on_timer_callback_;

@ -21,24 +21,7 @@ namespace ps {
namespace core {
WorkerNode::~WorkerNode() {
MS_LOG(INFO) << "Stop worker node!";
if (!is_already_stopped_.load()) {
is_ready_ = true;
is_timeout_ = true;
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
client_to_scheduler_->StopEventBase();
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
}
is_already_stopped_ = true;
}
Stop();
}
bool WorkerNode::Start(const uint32_t &timeout) {
MS_LOG(INFO) << "Starting worker node!";
@ -78,19 +61,15 @@ bool WorkerNode::Stop() {
if (!is_already_stopped_.load()) {
is_ready_ = true;
is_timeout_ = true;
is_finish_ = true;
heart_beat_thread_->join();
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
client_to_scheduler_->StopEventBase();
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
}
is_already_stopped_ = true;
}
return true;

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#define protected public
#include "ps/core/worker_node.h"
#undef protected
namespace mindspore {
namespace ps {
namespace core {
class TestAbstractNode : public UT::Common {
public:
TestAbstractNode() = default;
virtual ~TestAbstractNode() = default;
void SetUp() override {}
void TearDown() override {}
};
TEST_F(TestAbstractNode, NextExpectedRankRequestId) {
WorkerNode workerNode;
ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(0));
ASSERT_EQ(2, workerNode.NextExpectedRankRequestId(0));
ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(1));
}
} // namespace core
} // namespace ps
} // namespace mindspore
Loading…
Cancel
Save