From 09a15be89342c56ed74aa89add10a5dccc5a2a29 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Sat, 16 Jan 2021 11:18:17 +0800 Subject: [PATCH] protobuf change bytes to Any --- mindspore/ccsrc/ps/core/abstract_node.cc | 74 +++++++++++------------ mindspore/ccsrc/ps/core/abstract_node.h | 20 +++--- mindspore/ccsrc/ps/core/protos/comm.proto | 3 +- mindspore/ccsrc/ps/core/protos/ps.proto | 31 +++++++++- mindspore/ccsrc/ps/core/scheduler_node.cc | 27 ++++----- mindspore/ccsrc/ps/core/scheduler_node.h | 8 ++- mindspore/ccsrc/ps/core/server_node.cc | 1 + mindspore/ccsrc/ps/core/server_node.h | 3 - mindspore/ccsrc/ps/core/worker_node.cc | 1 + mindspore/ccsrc/ps/core/worker_node.h | 1 - 10 files changed, 97 insertions(+), 72 deletions(-) diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 293a1fe149..daac34850e 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -32,6 +32,7 @@ void AbstractNode::Register(const std::shared_ptr &client) { CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; comm_message.set_data(register_message.SerializeAsString()); + comm_message.set_user_cmd(""); if (!SendMessageSync(client, comm_message)) { MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) << " the node id:" << node_info_.node_id_ << " register timeout!"; @@ -54,11 +55,12 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) { MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; } -bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) { +bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) { if (node_role != NodeRole::SERVER) { MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes"; } + CommMessage &comm_message = const_cast(message); uint64_t request_id = ++next_request_id_; message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); @@ -69,9 +71,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string & message_meta.set_rank_id(node_info_.rank_id_); message_meta.set_role(node_info_.node_role_); - CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(message); auto client = GetOrCreateTcpClient((*it).first.second); client->SendMessage(comm_message); } @@ -84,26 +84,26 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me on_node_event_message_ = on_node_event_message; } -bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, +bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, const uint32_t &timeout) { if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } + CommMessage &comm_message = const_cast(message); + MessageMeta message_meta; message_meta.set_cmd(NodeCommand::SEND_DATA); message_meta.set_rank_id(node_info_.rank_id_); message_meta.set_role(node_info_.node_role_); - CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(message); auto client = GetOrCreateTcpClient(rank_id); return SendMessageSync(client, comm_message, timeout); } bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, const uint32_t &timeout) { + const std::vector &data, const uint32_t &timeout) { uint64_t request_id = ++next_request_id_; message_tracker_[request_id] = std::make_pair(data.size(), 0); @@ -121,9 +121,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & message_meta.set_rank_id(node_info_.rank_id_); message_meta.set_role(node_info_.node_role_); - CommMessage comm_message; + CommMessage &comm_message = const_cast(data.at(it)); *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(data.at(it)); auto client = GetOrCreateTcpClient(rank_ids.at(it)); client->SendMessage(comm_message); @@ -133,19 +132,21 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & return Wait(request_id, timeout); } -bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - std::string *output, const uint32_t &timeout) { +bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, + CommMessage *output, const uint32_t &timeout) { MS_EXCEPTION_IF_NULL(output); if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } + CommMessage &comm_message = const_cast(message); + uint64_t request_id = ++next_request_id_; message_tracker_[request_id] = std::make_pair(1, 0); set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); auto res = receive_messages_[request_id]; - *output = res[rank_id].data(); + *output = res[rank_id]; receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); }); @@ -156,9 +157,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, message_meta.set_rank_id(node_info_.rank_id_); message_meta.set_role(node_info_.node_role_); - CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(message); auto client = GetOrCreateTcpClient(rank_id); client->SendMessage(comm_message); MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) @@ -167,7 +166,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, } bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, std::vector *output, + const std::vector &data, std::vector *output, const uint32_t &timeout) { MS_EXCEPTION_IF_NULL(output); uint64_t request_id = ++next_request_id_; @@ -183,7 +182,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & receive_messages_mutex_.lock(); auto res = receive_messages_[request_id]; for (size_t it = 0; it < len; ++it) { - (*output).push_back(res[rank_ids.at(it)].data()); + (*output).push_back(res[rank_ids.at(it)]); } receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); @@ -200,9 +199,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & message_meta.set_rank_id(node_info_.rank_id_); message_meta.set_role(node_info_.node_role_); - CommMessage comm_message; + CommMessage &comm_message = const_cast(data.at(it)); *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(data.at(it)); auto client = GetOrCreateTcpClient(rank_ids.at(it)); client->SendMessage(comm_message); @@ -223,37 +221,37 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { } uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, - const std::string &message) { + const CommMessage &message) { if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } + CommMessage &comm_message = const_cast(message); + MessageMeta message_meta; message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA); message_meta.set_rank_id(node_info_.rank_id_); message_meta.set_role(node_info_.node_role_); - CommMessage comm_message; *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(message); auto client = GetOrCreateTcpClient(rank_id); return SendMessageAsync(client, comm_message); } std::pair AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, - const uint32_t &rank_id, std::string *output) { + const uint32_t &rank_id, CommMessage *output) { if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { - *output = received_data_[std::make_pair(rank_id, rank_request_id)].data(); + *output = received_data_[std::make_pair(rank_id, rank_request_id)]; received_data_.erase(std::make_pair(rank_id, rank_request_id)); } else { set_receive_callback(rank_id, rank_request_id, [=]() { receive_callbacks_mutex_.lock(); - *output = received_data_[std::make_pair(rank_id, 1)].data(); + *output = received_data_[std::make_pair(rank_id, rank_request_id)]; received_data_.erase(std::make_pair(rank_id, rank_request_id)); receive_callbacks_mutex_.unlock(); }); @@ -415,21 +413,12 @@ bool AbstractNode::InitClientToScheduler() { uint16_t scheduler_port = ClusterConfig::scheduler_port(); client_to_scheduler_ = std::make_shared(scheduler_host, scheduler_port); client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { - switch (message.pb_meta().cmd()) { - case NodeCommand::HEARTBEAT: - ProcessHeartbeatResp(message); - break; - case NodeCommand::REGISTER: - ProcessRegisterResp(message); - break; - case NodeCommand::FETCH_SERVER: - ProcessFetchServersResp(message); - break; - case NodeCommand::FINISH: - MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!"; - break; - default: - MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; + if (handlers_.count(message.pb_meta().cmd()) == 0) { + MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; + } + if (handlers_[message.pb_meta().cmd()] != nullptr) { + const auto &handler_ptr = handlers_[message.pb_meta().cmd()]; + (this->*handler_ptr)(message); } NotifyMessageArrival(message); }); @@ -607,6 +596,13 @@ uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) { } return rank_request_id; } + +void AbstractNode::InitCommandHandler() { + handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp; + handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp; + handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp; + handlers_[NodeCommand::FINISH] = nullptr; +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h index eea8eb773d..448e36489a 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.h +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -34,23 +34,25 @@ class AbstractNode : public Node { AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} ~AbstractNode() override = default; - bool Broadcast(const enum NodeRole &node_role, const std::string &message, + typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message); + + bool Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout = kCommTimeoutInSeconds); void set_event_callback(const OnNodeEventMessage &on_node_event_message); - bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, + bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, const uint32_t &timeout = kCommTimeoutInSeconds); - bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, + bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, const uint32_t &timeout = kCommTimeoutInSeconds); - bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output, + bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output, const uint32_t &timeout = kCommTimeoutInSeconds); - bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, - std::vector *output, const uint32_t &timeout = kCommTimeoutInSeconds); + bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, + std::vector *output, const uint32_t &timeout = kCommTimeoutInSeconds); bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); - uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message); + uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message); std::pair CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, - std::string *output); + CommMessage *output); bool CollectiveWait(std::pair request_id, const uint32_t &timeout = kCommTimeoutInSeconds); protected: @@ -78,6 +80,7 @@ class AbstractNode : public Node { void RunReceiveCallback(const CommMessage &message); uint64_t NextExpectedRankRequestId(const uint32_t &rank_id); uint64_t NextActualRankRequestId(const uint32_t &rank_id); + void InitCommandHandler(); std::unique_ptr heart_beat_thread_; std::unique_ptr client_to_scheduler_thread_; @@ -115,6 +118,7 @@ class AbstractNode : public Node { std::unordered_map actual_rank_request_ids_; std::mutex rank_request_ids_mutex; timeval scheduler_time_; + std::unordered_map handlers_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto index 4e24de8c58..81d1013712 100644 --- a/mindspore/ccsrc/ps/core/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -95,5 +95,6 @@ message FinishMessage { message CommMessage { MessageMeta pb_meta = 1; bytes data = 2; + // User-defined commands + bytes user_cmd = 3; } - diff --git a/mindspore/ccsrc/ps/core/protos/ps.proto b/mindspore/ccsrc/ps/core/protos/ps.proto index 9ae31a94c1..7f293663a1 100644 --- a/mindspore/ccsrc/ps/core/protos/ps.proto +++ b/mindspore/ccsrc/ps/core/protos/ps.proto @@ -14,17 +14,42 @@ * limitations under the License. */ syntax = "proto3"; -package mindspore.ps.core; +package mindspore.ps; option optimize_for = LITE_RUNTIME; -enum PSCommand { +message Command { + CommandCode cmd = 1; +} + +enum CommandCode { PUSH = 0; PULL = 1; INIT_EMBEDDING_TABLE = 2; + INIT_WEIGHT = 3; + INIT_WEIGHT_TO_OPTIM_ID = 4; + INIT_INPUTS_SHAPE = 5; + CHECK_READY_FOR_PUSH = 6; + CHECK_READY_FOR_PULL = 7; + EMBEDDING_LOOKUP = 8; + UPDATE_EMBEDDING = 9; + FINALIZE = 10; } message KVMessage { - PSCommand command = 1; repeated int32 keys = 2; repeated float values = 3; + repeated int32 len = 4; +} + +message EmbeddingTableMeta { + uint64 key = 1; + repeated uint64 input_shape = 2; + repeated uint64 indices_shape = 3; + repeated uint64 output_shape = 4; +} + +message EmbeddingTableLookup { + uint64 key = 2; + repeated int32 keys = 3; + repeated float values = 4; } \ No newline at end of file diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index d84fc77dc4..a3a38519fb 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -67,6 +67,7 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr server, std::sha } void SchedulerNode::Initialize() { + InitCommandHandler(); CreateTcpServer(); is_already_stopped_ = false; node_info_.node_id_ = CommUtil::GenerateUUID(); @@ -75,6 +76,13 @@ void SchedulerNode::Initialize() { << ", the node id is:" << node_info_.node_id_; } +void SchedulerNode::InitCommandHandler() { + handlers_[NodeCommand::HEARTBEAT] = &SchedulerNode::ProcessHeartbeat; + handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister; + handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish; + handlers_[NodeCommand::FETCH_SERVER] = &SchedulerNode::ProcessFetchServers; +} + void SchedulerNode::CreateTcpServer() { node_manager_.InitNodeNum(); @@ -82,22 +90,11 @@ void SchedulerNode::CreateTcpServer() { uint32_t scheduler_port = ClusterConfig::scheduler_port(); server_ = std::make_shared(scheduler_host, scheduler_port); server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr message) { - switch (message->pb_meta().cmd()) { - case NodeCommand::HEARTBEAT: - ProcessHeartbeat(server_, conn, message); - break; - case NodeCommand::REGISTER: - ProcessRegister(server_, conn, message); - break; - case NodeCommand::FINISH: - ProcessFinish(server_, conn, message); - break; - case NodeCommand::FETCH_SERVER: - ProcessFetchServers(server_, conn, message); - break; - default: - MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; + if (handlers_.count(message->pb_meta().cmd()) == 0) { + MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; } + const auto &handler_ptr = handlers_[message->pb_meta().cmd()]; + (this->*handler_ptr)(server_, conn, message); }); server_->Init(); diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index a476caae53..1c89d2398d 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -25,29 +25,32 @@ #include #include #include +#include #include "ps/core/cluster_config.h" #include "ps/core/tcp_client.h" #include "ps/core/tcp_server.h" #include "ps/core/node_manager.h" #include "ps/core/node.h" -#include "utils/log_adapter.h" namespace mindspore { namespace ps { namespace core { - class SchedulerNode : public Node { public: SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} ~SchedulerNode() override; + typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr server, std::shared_ptr conn, + std::shared_ptr message); + bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; private: void Initialize(); + void InitCommandHandler(); void CreateTcpServer(); void ProcessHeartbeat(std::shared_ptr server, std::shared_ptr conn, std::shared_ptr message); @@ -62,6 +65,7 @@ class SchedulerNode : public Node { std::shared_ptr server_; std::unique_ptr scheduler_thread_; std::unique_ptr update_state_thread_; + std::unordered_map handlers_; NodeManager node_manager_; }; diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 08d0b280b8..28d0957067 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -92,6 +92,7 @@ void ServerNode::Initialize() { node_info_.port_ = server_->BoundPort(); MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) << " is generate uuid is:" << node_info_.node_id_; + InitCommandHandler(); if (!InitClientToScheduler()) { MS_LOG(EXCEPTION) << "Server node init client timeout!"; } diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h index 2a0d70e82b..086358f56e 100644 --- a/mindspore/ccsrc/ps/core/server_node.h +++ b/mindspore/ccsrc/ps/core/server_node.h @@ -24,13 +24,10 @@ #include #include -#include "proto/comm.pb.h" -#include "proto/ps.pb.h" #include "ps/core/cluster_config.h" #include "ps/core/tcp_client.h" #include "ps/core/tcp_server.h" #include "ps/core/abstract_node.h" -#include "utils/log_adapter.h" namespace mindspore { namespace ps { diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index ee162e070b..1870a49924 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -50,6 +50,7 @@ void WorkerNode::Initialize() { node_info_.node_role_ = NodeRole::WORKER; MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_; + InitCommandHandler(); if (!InitClientToScheduler()) { MS_LOG(EXCEPTION) << "Worker node init client timeout!"; } diff --git a/mindspore/ccsrc/ps/core/worker_node.h b/mindspore/ccsrc/ps/core/worker_node.h index a1343aa362..8608ae430a 100644 --- a/mindspore/ccsrc/ps/core/worker_node.h +++ b/mindspore/ccsrc/ps/core/worker_node.h @@ -28,7 +28,6 @@ #include "ps/core/tcp_client.h" #include "ps/core/tcp_server.h" #include "ps/core/abstract_node.h" -#include "utils/log_adapter.h" namespace mindspore { namespace ps {