|
|
|
@ -32,6 +32,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
|
|
|
|
|
CommMessage comm_message;
|
|
|
|
|
*comm_message.mutable_pb_meta() = {message_meta};
|
|
|
|
|
comm_message.set_data(register_message.SerializeAsString());
|
|
|
|
|
comm_message.set_user_cmd("");
|
|
|
|
|
if (!SendMessageSync(client, comm_message)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
|
|
|
|
<< " the node id:" << node_info_.node_id_ << " register timeout!";
|
|
|
|
@ -54,11 +55,12 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
|
|
|
|
|
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) {
|
|
|
|
|
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) {
|
|
|
|
|
if (node_role != NodeRole::SERVER) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(message);
|
|
|
|
|
uint64_t request_id = ++next_request_id_;
|
|
|
|
|
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
|
|
|
|
|
|
|
|
|
@ -69,9 +71,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &
|
|
|
|
|
message_meta.set_rank_id(node_info_.rank_id_);
|
|
|
|
|
message_meta.set_role(node_info_.node_role_);
|
|
|
|
|
|
|
|
|
|
CommMessage comm_message;
|
|
|
|
|
*comm_message.mutable_pb_meta() = {message_meta};
|
|
|
|
|
comm_message.set_data(message);
|
|
|
|
|
auto client = GetOrCreateTcpClient((*it).first.second);
|
|
|
|
|
client->SendMessage(comm_message);
|
|
|
|
|
}
|
|
|
|
@ -84,26 +84,26 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
|
|
|
|
|
on_node_event_message_ = on_node_event_message;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
|
|
|
|
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
|
|
|
|
|
const uint32_t &timeout) {
|
|
|
|
|
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(message);
|
|
|
|
|
|
|
|
|
|
MessageMeta message_meta;
|
|
|
|
|
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
|
|
|
|
message_meta.set_rank_id(node_info_.rank_id_);
|
|
|
|
|
message_meta.set_role(node_info_.node_role_);
|
|
|
|
|
|
|
|
|
|
CommMessage comm_message;
|
|
|
|
|
*comm_message.mutable_pb_meta() = {message_meta};
|
|
|
|
|
comm_message.set_data(message);
|
|
|
|
|
auto client = GetOrCreateTcpClient(rank_id);
|
|
|
|
|
return SendMessageSync(client, comm_message, timeout);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
|
|
|
|
const std::vector<std::string> &data, const uint32_t &timeout) {
|
|
|
|
|
const std::vector<CommMessage> &data, const uint32_t &timeout) {
|
|
|
|
|
uint64_t request_id = ++next_request_id_;
|
|
|
|
|
message_tracker_[request_id] = std::make_pair(data.size(), 0);
|
|
|
|
|
|
|
|
|
@ -121,9 +121,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|
|
|
|
message_meta.set_rank_id(node_info_.rank_id_);
|
|
|
|
|
message_meta.set_role(node_info_.node_role_);
|
|
|
|
|
|
|
|
|
|
CommMessage comm_message;
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
|
|
|
|
|
*comm_message.mutable_pb_meta() = {message_meta};
|
|
|
|
|
comm_message.set_data(data.at(it));
|
|
|
|
|
|
|
|
|
|
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
|
|
|
|
client->SendMessage(comm_message);
|
|
|
|
@ -133,19 +132,21 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|
|
|
|
return Wait(request_id, timeout);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
|
|
|
|
std::string *output, const uint32_t &timeout) {
|
|
|
|
|
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
|
|
|
|
|
CommMessage *output, const uint32_t &timeout) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
|
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(message);
|
|
|
|
|
|
|
|
|
|
uint64_t request_id = ++next_request_id_;
|
|
|
|
|
message_tracker_[request_id] = std::make_pair(1, 0);
|
|
|
|
|
set_message_callback(request_id, [&]() {
|
|
|
|
|
receive_messages_mutex_.lock();
|
|
|
|
|
auto res = receive_messages_[request_id];
|
|
|
|
|
*output = res[rank_id].data();
|
|
|
|
|
*output = res[rank_id];
|
|
|
|
|
receive_messages_.erase(request_id);
|
|
|
|
|
receive_messages_mutex_.unlock();
|
|
|
|
|
});
|
|
|
|
@ -156,9 +157,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
|
|
|
|
|
message_meta.set_rank_id(node_info_.rank_id_);
|
|
|
|
|
message_meta.set_role(node_info_.node_role_);
|
|
|
|
|
|
|
|
|
|
CommMessage comm_message;
|
|
|
|
|
*comm_message.mutable_pb_meta() = {message_meta};
|
|
|
|
|
comm_message.set_data(message);
|
|
|
|
|
auto client = GetOrCreateTcpClient(rank_id);
|
|
|
|
|
client->SendMessage(comm_message);
|
|
|
|
|
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
|
|
|
@ -167,7 +166,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
|
|
|
|
const std::vector<std::string> &data, std::vector<std::string> *output,
|
|
|
|
|
const std::vector<CommMessage> &data, std::vector<CommMessage> *output,
|
|
|
|
|
const uint32_t &timeout) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
|
uint64_t request_id = ++next_request_id_;
|
|
|
|
@ -183,7 +182,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|
|
|
|
receive_messages_mutex_.lock();
|
|
|
|
|
auto res = receive_messages_[request_id];
|
|
|
|
|
for (size_t it = 0; it < len; ++it) {
|
|
|
|
|
(*output).push_back(res[rank_ids.at(it)].data());
|
|
|
|
|
(*output).push_back(res[rank_ids.at(it)]);
|
|
|
|
|
}
|
|
|
|
|
receive_messages_.erase(request_id);
|
|
|
|
|
receive_messages_mutex_.unlock();
|
|
|
|
@ -200,9 +199,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|
|
|
|
message_meta.set_rank_id(node_info_.rank_id_);
|
|
|
|
|
message_meta.set_role(node_info_.node_role_);
|
|
|
|
|
|
|
|
|
|
CommMessage comm_message;
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
|
|
|
|
|
*comm_message.mutable_pb_meta() = {message_meta};
|
|
|
|
|
comm_message.set_data(data.at(it));
|
|
|
|
|
|
|
|
|
|
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
|
|
|
|
client->SendMessage(comm_message);
|
|
|
|
@ -223,37 +221,37 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
|
|
|
|
|
const std::string &message) {
|
|
|
|
|
const CommMessage &message) {
|
|
|
|
|
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(message);
|
|
|
|
|
|
|
|
|
|
MessageMeta message_meta;
|
|
|
|
|
message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
|
|
|
|
|
message_meta.set_rank_id(node_info_.rank_id_);
|
|
|
|
|
message_meta.set_role(node_info_.node_role_);
|
|
|
|
|
|
|
|
|
|
CommMessage comm_message;
|
|
|
|
|
*comm_message.mutable_pb_meta() = {message_meta};
|
|
|
|
|
comm_message.set_data(message);
|
|
|
|
|
auto client = GetOrCreateTcpClient(rank_id);
|
|
|
|
|
return SendMessageAsync(client, comm_message);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
|
|
|
|
|
const uint32_t &rank_id, std::string *output) {
|
|
|
|
|
const uint32_t &rank_id, CommMessage *output) {
|
|
|
|
|
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
|
|
|
|
|
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
|
|
|
|
|
*output = received_data_[std::make_pair(rank_id, rank_request_id)].data();
|
|
|
|
|
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
|
|
|
|
|
received_data_.erase(std::make_pair(rank_id, rank_request_id));
|
|
|
|
|
} else {
|
|
|
|
|
set_receive_callback(rank_id, rank_request_id, [=]() {
|
|
|
|
|
receive_callbacks_mutex_.lock();
|
|
|
|
|
*output = received_data_[std::make_pair(rank_id, 1)].data();
|
|
|
|
|
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
|
|
|
|
|
received_data_.erase(std::make_pair(rank_id, rank_request_id));
|
|
|
|
|
receive_callbacks_mutex_.unlock();
|
|
|
|
|
});
|
|
|
|
@ -415,21 +413,12 @@ bool AbstractNode::InitClientToScheduler() {
|
|
|
|
|
uint16_t scheduler_port = ClusterConfig::scheduler_port();
|
|
|
|
|
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
|
|
|
|
|
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
|
|
|
|
|
switch (message.pb_meta().cmd()) {
|
|
|
|
|
case NodeCommand::HEARTBEAT:
|
|
|
|
|
ProcessHeartbeatResp(message);
|
|
|
|
|
break;
|
|
|
|
|
case NodeCommand::REGISTER:
|
|
|
|
|
ProcessRegisterResp(message);
|
|
|
|
|
break;
|
|
|
|
|
case NodeCommand::FETCH_SERVER:
|
|
|
|
|
ProcessFetchServersResp(message);
|
|
|
|
|
break;
|
|
|
|
|
case NodeCommand::FINISH:
|
|
|
|
|
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
|
|
|
|
if (handlers_.count(message.pb_meta().cmd()) == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
|
|
|
|
}
|
|
|
|
|
if (handlers_[message.pb_meta().cmd()] != nullptr) {
|
|
|
|
|
const auto &handler_ptr = handlers_[message.pb_meta().cmd()];
|
|
|
|
|
(this->*handler_ptr)(message);
|
|
|
|
|
}
|
|
|
|
|
NotifyMessageArrival(message);
|
|
|
|
|
});
|
|
|
|
@ -607,6 +596,13 @@ uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
|
|
|
|
|
}
|
|
|
|
|
return rank_request_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Fills the command -> member-function table stored in handlers_.
// HEARTBEAT, REGISTER and FETCH_SERVER map to their response processors; FINISH is registered
// with a null handler, i.e. the command is known but has no processing routine attached.
void AbstractNode::InitCommandHandler() {
  auto &dispatch_table = handlers_;

  // Commands that have a dedicated response processor.
  dispatch_table[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
  dispatch_table[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
  dispatch_table[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;

  // Known command without a processor.
  dispatch_table[NodeCommand::FINISH] = nullptr;
}
|
|
|
|
|
} // namespace core
|
|
|
|
|
} // namespace ps
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|