diff --git a/mindspore/ccsrc/ps/constants.h b/mindspore/ccsrc/ps/constants.h index 59e4284587..b6e2824052 100644 --- a/mindspore/ccsrc/ps/constants.h +++ b/mindspore/ccsrc/ps/constants.h @@ -67,7 +67,7 @@ constexpr int64_t kPullCmd = 51; constexpr size_t kInvalidKey = UINT64_MAX; constexpr int64_t kInvalidID = -1; -using DataPtr = std::shared_ptr; +using DataPtr = std::shared_ptr; using VectorPtr = std::shared_ptr>; using Key = uint64_t; using Keys = std::vector; diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 4e464f5387..9b14ce48db 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -281,7 +281,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) if (!Heartbeat(client)) { MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; - if (!CheckSchedulerTimeout() && on_node_event_message_) { + if (CheckSchedulerTimeout() && on_node_event_message_) { MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!"; is_finish_ = true; @@ -294,6 +294,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval())); } }); + heart_beat_thread_->detach(); } bool AbstractNode::Heartbeat(const std::shared_ptr &client, bool is_node_finish) { @@ -307,6 +308,7 @@ bool AbstractNode::Heartbeat(const std::shared_ptr &client, bool is_n if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(), heartbeat_message.ByteSizeLong())) { MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; + return false; } return true; } @@ -315,9 +317,7 @@ void AbstractNode::UpdateSchedulerTime() { struct timeval current_time {}; (void)gettimeofday(¤t_time, nullptr); scheduler_time_ = current_time; - MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) - << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ - << " update scheduler time, the current time is: " << current_time.tv_sec; + MS_LOG(DEBUG) << "Update scheduler time, the current time is: " << current_time.tv_sec; } bool AbstractNode::CheckSchedulerTimeout() const { @@ -430,10 +430,13 @@ bool AbstractNode::InitClientToScheduler() { MS_LOG(INFO) << "The node start a tcp client!"; client_to_scheduler_->Start(); }); + client_to_scheduler_thread_->detach(); client_to_scheduler_->set_disconnected_callback([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(ClusterMetadata::instance()->connect_interval())); - client_to_scheduler_->Init(); + if (is_ready_.load() == false) { + client_to_scheduler_->Init(); + } }); return client_to_scheduler_->WaitConnected(); } diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h index a4503829bd..434df7160e 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.h +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -37,7 +37,7 @@ class AbstractNode : public Node { typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr meta, const void *data, size_t size); - using DataPtr = std::shared_ptr; + using DataPtr = std::shared_ptr; using VectorPtr = std::shared_ptr>; bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command, diff --git a/mindspore/ccsrc/ps/core/cluster_metadata.h b/mindspore/ccsrc/ps/core/cluster_metadata.h index f27479af94..2039aba0b7 100644 --- a/mindspore/ccsrc/ps/core/cluster_metadata.h +++ b/mindspore/ccsrc/ps/core/cluster_metadata.h @@ -62,7 +62,7 @@ class ClusterMetadata { heartbeat_timeout_(30), cluster_available_timeout_(300), connect_interval_(100), - scheduler_timeout_(3600 * 5) {} + scheduler_timeout_(30) {} uint32_t worker_num_; uint32_t server_num_; // The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds. diff --git a/mindspore/ccsrc/ps/core/node_info.h b/mindspore/ccsrc/ps/core/node_info.h index b421cf2ad6..41950ca897 100644 --- a/mindspore/ccsrc/ps/core/node_info.h +++ b/mindspore/ccsrc/ps/core/node_info.h @@ -25,7 +25,7 @@ namespace mindspore { namespace ps { namespace core { -enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT }; +enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT = 2 }; struct NodeInfo { NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 62b54ad840..26f39d1b96 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -105,7 +105,7 @@ void ServerNode::ProcessSendData(std::shared_ptr conn, std::share MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(meta); MS_EXCEPTION_IF_NULL(data); - std::shared_ptr res(new unsigned char[size]); + std::shared_ptr res(new unsigned char[size]); int ret = memcpy_s(res.get(), size, data, size); if (ret != 0) { MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; @@ -131,14 +131,18 @@ bool ServerNode::Stop() { if (!is_already_stopped_.load()) { is_already_stopped_ = true; is_finish_ = true; - heart_beat_thread_->join(); + if (heart_beat_thread_->joinable()) { + heart_beat_thread_->join(); + } client_to_scheduler_->Stop(); if (!connected_nodes_.empty()) { for (auto &connected_node : connected_nodes_) { connected_node.second->Stop(); } } - client_to_scheduler_thread_->join(); + if (client_to_scheduler_thread_->joinable()) { + client_to_scheduler_thread_->join(); + } server_->Stop(); server_thread_->join(); } diff --git a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc index eec706b746..25b0c9cbb9 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -311,7 +311,8 @@ bool TcpClient::SendMessage(std::shared_ptr meta, const Protos &pro } int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH); if (result < 0) { - MS_LOG(EXCEPTION) << "Bufferevent flush failed!"; + MS_LOG(ERROR) << "Bufferevent flush failed!"; + res = false; } bufferevent_unlock(buffer_event_); return res; diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index 1870a49924..d4138c0751 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -63,14 +63,18 @@ bool WorkerNode::Stop() { is_ready_ = true; is_timeout_ = true; is_finish_ = true; - heart_beat_thread_->join(); + if (heart_beat_thread_->joinable()) { + heart_beat_thread_->join(); + } client_to_scheduler_->Stop(); if (!connected_nodes_.empty()) { for (auto &connected_node : connected_nodes_) { connected_node.second->Stop(); } } - client_to_scheduler_thread_->join(); + if (client_to_scheduler_thread_->joinable()) { + client_to_scheduler_thread_->join(); + } is_already_stopped_ = true; } return true; diff --git a/mindspore/ccsrc/ps/parameter_server.cc b/mindspore/ccsrc/ps/parameter_server.cc index 71d330be47..4a7256bd8e 100644 --- a/mindspore/ccsrc/ps/parameter_server.cc +++ b/mindspore/ccsrc/ps/parameter_server.cc @@ -21,6 +21,8 @@ namespace ps { void ParameterServer::Run(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_LOG(INFO) << "PServer starts connecting to scheduler and workers..."; + server_node_ = std::make_shared(); + core::ClusterMetadata::instance()->Init( PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); @@ -30,14 +32,14 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { return; } Init(func_graph); - server_node_.Start(); - rank_id_ = server_node_.rank_id(); + server_node_->Start(); + rank_id_ = server_node_->rank_id(); PSContext::instance()->SetPSRankId(rank_id_); thread_->join(); SyncEmbeddingTables(); MS_LOG(INFO) << "PServer finished updating models, starts finalizing..."; - server_node_.Finish(); - server_node_.Stop(); + server_node_->Finish(); + server_node_->Stop(); MS_LOG(INFO) << "PServer finalized successfully."; } @@ -49,7 +51,14 @@ bool ParameterServer::Init(const FuncGraphPtr &func_graph) { handler_->Init(); InitOptimInfoBuilders(); - server_node_.set_handler(*handler_); + server_node_->set_handler(*handler_); + server_node_->set_event_callback([&](const core::NodeEvent &event) { + if ((event == core::NodeEvent::CLUSTER_TIMEOUT) || + (event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) { + MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!"; + Finalize(); + } + }); thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); GetEmbeddingTableParamPtr(); return true; @@ -496,7 +505,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptruser_cmd()]; (this->*handler_ptr)(data, size, output); - std::shared_ptr res(new unsigned char[output->size()]); + std::shared_ptr res(new unsigned char[output->size()]); MS_LOG(DEBUG) << "The output size is:" << output->size(); if (output->size() > 0) { int ret = memcpy_s(res.get(), output->size(), output->data(), output->size()); @@ -505,7 +514,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptrserver_node_.Response(conn, meta, res, output->size()); + ps_->server_node_->Response(conn, meta, res, output->size()); MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:" << std::chrono::time_point_cast(std::chrono::high_resolution_clock::now()) .time_since_epoch() @@ -682,6 +691,7 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t *res_data.mutable_keys() = {input.keys().begin(), input.keys().end()}; ps_->DoEmbeddingLookup(key, keys, &res_data); + res->resize(res_data.ByteSizeLong()); int ret = memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong()); diff --git a/mindspore/ccsrc/ps/parameter_server.h b/mindspore/ccsrc/ps/parameter_server.h index fdccac3030..4f312c4b3c 100644 --- a/mindspore/ccsrc/ps/parameter_server.h +++ b/mindspore/ccsrc/ps/parameter_server.h @@ -59,6 +59,7 @@ #include "proto/comm.pb.h" #include "proto/ps.pb.h" #include "ps/core/server_node.h" +#include "ps/core/node.h" namespace mindspore { namespace ps { @@ -82,7 +83,8 @@ class ParameterServer { func_graph_(nullptr), sess_(nullptr), running_(true), - thread_(nullptr) {} + thread_(nullptr), + server_node_(nullptr) {} ~ParameterServer() = default; ParameterServer(const ParameterServer &) = delete; ParameterServer &operator=(const ParameterServer &) = delete; @@ -167,7 +169,7 @@ class ParameterServer { std::condition_variable apply_grads_cv_; std::unique_ptr thread_; - core::ServerNode server_node_; + std::shared_ptr server_node_; std::map embedding_tables_; friend class ServerHandler; diff --git a/mindspore/ccsrc/ps/worker.cc b/mindspore/ccsrc/ps/worker.cc index 42ca12c457..360975c5c2 100644 --- a/mindspore/ccsrc/ps/worker.cc +++ b/mindspore/ccsrc/ps/worker.cc @@ -15,11 +15,13 @@ */ #include "ps/worker.h" +#include "pipeline/jit/pipeline.h" namespace mindspore { namespace ps { void Worker::Run() { std::lock_guard lock(running_mutex_); + core::ClusterMetadata::instance()->Init( PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); @@ -33,6 +35,14 @@ void Worker::Run() { } Initialize(); + worker_node_.set_event_callback([&](const core::NodeEvent &event) { + if ((event == core::NodeEvent::CLUSTER_TIMEOUT) || + (event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) { + MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!"; + Finalize(); + exit(0); + } + }); MS_LOG(INFO) << "Worker starts connecting to scheduler and server..."; worker_node_.Start(); MS_LOG(INFO) << "Worker connected successfully."; @@ -86,7 +96,7 @@ void Worker::Push(const std::vector &keys, std::vector addrs, } MS_LOG(INFO) << "The total size is:" << total_size; - while (!IsReadyForPush(keys[0])) { + while (running_ && (!IsReadyForPush(keys[0]))) { continue; } std::vector sizes_int; @@ -109,7 +119,7 @@ void Worker::Push(const std::vector &keys, std::vector addrs, void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { MS_EXCEPTION_IF_NULL(dev_addr); std::vector variables(size / sizeof(float), 0); - while (!IsReadyForPull(key)) { + while (running_ && (!IsReadyForPull(key))) { continue; } PullData({key}, &variables, nullptr, kPullCmd); @@ -214,7 +224,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector & std::string kv_data = embedding_table_meta.SerializeAsString(); - std::shared_ptr res(new unsigned char[kv_data.length()]); + std::shared_ptr res(new unsigned char[kv_data.length()]); int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; @@ -280,7 +290,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector &lookup_ rank_ids.push_back(i); std::string kv_data = messages.at(i).second.SerializeAsString(); - std::shared_ptr res(new unsigned char[kv_data.length()]); + std::shared_ptr res(new unsigned char[kv_data.length()]); int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; @@ -303,7 +313,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector &lookup_ for (auto j = 0; j < message.values_size(); j++) { values->push_back(message.values(j)); } - MS_LOG(DEBUG) << "The embedding resp:" << values; + MS_LOG(DEBUG) << "The embedding resp:" << *values; for (auto k = 0; k < message.keys_size(); k++) { const Key &key = message.keys(k); float *addr = values->data() + value_offset; @@ -358,7 +368,7 @@ void Worker::UpdateEmbeddingTable(const std::vector &keys, const std::vecto rank_ids.push_back(i); std::string kv_data = messages.at(i).second.SerializeAsString(); - std::shared_ptr res(new unsigned char[kv_data.length()]); + std::shared_ptr res(new unsigned char[kv_data.length()]); int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; @@ -378,7 +388,7 @@ void Worker::Finalize() { kvs.add_keys(0); kvs.add_values(0.0f); std::string kv_data = kvs.SerializeAsString(); - std::shared_ptr res(new unsigned char[kv_data.length()]); + std::shared_ptr res(new unsigned char[kv_data.length()]); int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; @@ -619,7 +629,7 @@ void Worker::PushData(const std::vector &keys, const std::vector &va SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {}); } else { std::string kv_data = kvs.SerializeAsString(); - std::shared_ptr res(new unsigned char[kv_data.length()]); + std::shared_ptr res(new unsigned char[kv_data.length()]); int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; @@ -920,7 +930,7 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa rank_ids.push_back(i); std::string kv_data = messages.at(i).second.SerializeAsString(); - std::shared_ptr res(new unsigned char[kv_data.length()]); + std::shared_ptr res(new unsigned char[kv_data.length()]); int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; @@ -945,7 +955,7 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa rank_ids.push_back(i); std::string kv_data = messages.at(i).second.SerializeAsString(); - std::shared_ptr res(new unsigned char[kv_data.length()]); + std::shared_ptr res(new unsigned char[kv_data.length()]); int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";