!12929 added event callback

From: @anancds
Reviewed-by: 
Signed-off-by:
pull/12929/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 37560318ef

@ -67,7 +67,7 @@ constexpr int64_t kPullCmd = 51;
constexpr size_t kInvalidKey = UINT64_MAX;
constexpr int64_t kInvalidID = -1;
using DataPtr = std::shared_ptr<unsigned char>;
using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = uint64_t;
using Keys = std::vector<Key>;

@ -281,7 +281,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
if (!Heartbeat(client)) {
MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
if (!CheckSchedulerTimeout() && on_node_event_message_) {
if (CheckSchedulerTimeout() && on_node_event_message_) {
MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!";
is_finish_ = true;
@ -294,6 +294,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval()));
}
});
heart_beat_thread_->detach();
}
bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
@ -307,6 +308,7 @@ bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_n
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
heartbeat_message.ByteSizeLong())) {
MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
return false;
}
return true;
}
@ -315,9 +317,7 @@ void AbstractNode::UpdateSchedulerTime() {
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
scheduler_time_ = current_time;
MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
<< " update scheduler time, the current time is: " << current_time.tv_sec;
MS_LOG(DEBUG) << "Update scheduler time, the current time is: " << current_time.tv_sec;
}
bool AbstractNode::CheckSchedulerTimeout() const {
@ -430,10 +430,13 @@ bool AbstractNode::InitClientToScheduler() {
MS_LOG(INFO) << "The node start a tcp client!";
client_to_scheduler_->Start();
});
client_to_scheduler_thread_->detach();
client_to_scheduler_->set_disconnected_callback([&]() {
std::this_thread::sleep_for(std::chrono::milliseconds(ClusterMetadata::instance()->connect_interval()));
client_to_scheduler_->Init();
if (is_ready_.load() == false) {
client_to_scheduler_->Init();
}
});
return client_to_scheduler_->WaitConnected();
}

@ -37,7 +37,7 @@ class AbstractNode : public Node {
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
using DataPtr = std::shared_ptr<unsigned char>;
using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,

@ -62,7 +62,7 @@ class ClusterMetadata {
heartbeat_timeout_(30),
cluster_available_timeout_(300),
connect_interval_(100),
scheduler_timeout_(3600 * 5) {}
scheduler_timeout_(30) {}
uint32_t worker_num_;
uint32_t server_num_;
// The interval for sending heartbeat packets between the worker node, server node and scheduler node is 3 seconds.

@ -25,7 +25,7 @@
namespace mindspore {
namespace ps {
namespace core {
enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT };
enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT = 2 };
struct NodeInfo {
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}

@ -105,7 +105,7 @@ void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::share
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
std::shared_ptr<unsigned char> res(new unsigned char[size]);
std::shared_ptr<unsigned char[]> res(new unsigned char[size]);
int ret = memcpy_s(res.get(), size, data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
@ -131,14 +131,18 @@ bool ServerNode::Stop() {
if (!is_already_stopped_.load()) {
is_already_stopped_ = true;
is_finish_ = true;
heart_beat_thread_->join();
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
}
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
client_to_scheduler_thread_->join();
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
server_->Stop();
server_thread_->join();
}

@ -311,7 +311,8 @@ bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &pro
}
int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
if (result < 0) {
MS_LOG(EXCEPTION) << "Bufferevent flush failed!";
MS_LOG(ERROR) << "Bufferevent flush failed!";
res = false;
}
bufferevent_unlock(buffer_event_);
return res;

@ -63,14 +63,18 @@ bool WorkerNode::Stop() {
is_ready_ = true;
is_timeout_ = true;
is_finish_ = true;
heart_beat_thread_->join();
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
}
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
client_to_scheduler_thread_->join();
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
is_already_stopped_ = true;
}
return true;

@ -21,6 +21,8 @@ namespace ps {
void ParameterServer::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
server_node_ = std::make_shared<core::ServerNode>();
core::ClusterMetadata::instance()->Init(
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
@ -30,14 +32,14 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
return;
}
Init(func_graph);
server_node_.Start();
rank_id_ = server_node_.rank_id();
server_node_->Start();
rank_id_ = server_node_->rank_id();
PSContext::instance()->SetPSRankId(rank_id_);
thread_->join();
SyncEmbeddingTables();
MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
server_node_.Finish();
server_node_.Stop();
server_node_->Finish();
server_node_->Stop();
MS_LOG(INFO) << "PServer finalized successfully.";
}
@ -49,7 +51,14 @@ bool ParameterServer::Init(const FuncGraphPtr &func_graph) {
handler_->Init();
InitOptimInfoBuilders();
server_node_.set_handler(*handler_);
server_node_->set_handler(*handler_);
server_node_->set_event_callback([&](const core::NodeEvent &event) {
if ((event == core::NodeEvent::CLUSTER_TIMEOUT) ||
(event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) {
MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!";
Finalize();
}
});
thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
GetEmbeddingTableParamPtr();
return true;
@ -496,7 +505,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect
auto &handler_ptr = handlers_[meta->user_cmd()];
(this->*handler_ptr)(data, size, output);
std::shared_ptr<unsigned char> res(new unsigned char[output->size()]);
std::shared_ptr<unsigned char[]> res(new unsigned char[output->size()]);
MS_LOG(DEBUG) << "The output size is:" << output->size();
if (output->size() > 0) {
int ret = memcpy_s(res.get(), output->size(), output->data(), output->size());
@ -505,7 +514,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect
}
}
ps_->server_node_.Response(conn, meta, res, output->size());
ps_->server_node_->Response(conn, meta, res, output->size());
MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:"
<< std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now())
.time_since_epoch()
@ -682,6 +691,7 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t
*res_data.mutable_keys() = {input.keys().begin(), input.keys().end()};
ps_->DoEmbeddingLookup(key, keys, &res_data);
res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());

@ -59,6 +59,7 @@
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/server_node.h"
#include "ps/core/node.h"
namespace mindspore {
namespace ps {
@ -82,7 +83,8 @@ class ParameterServer {
func_graph_(nullptr),
sess_(nullptr),
running_(true),
thread_(nullptr) {}
thread_(nullptr),
server_node_(nullptr) {}
~ParameterServer() = default;
ParameterServer(const ParameterServer &) = delete;
ParameterServer &operator=(const ParameterServer &) = delete;
@ -167,7 +169,7 @@ class ParameterServer {
std::condition_variable apply_grads_cv_;
std::unique_ptr<std::thread> thread_;
core::ServerNode server_node_;
std::shared_ptr<core::ServerNode> server_node_;
std::map<Key, ParameterPtr> embedding_tables_;
friend class ServerHandler;

@ -15,11 +15,13 @@
*/
#include "ps/worker.h"
#include "pipeline/jit/pipeline.h"
namespace mindspore {
namespace ps {
void Worker::Run() {
std::lock_guard<std::mutex> lock(running_mutex_);
core::ClusterMetadata::instance()->Init(
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
@ -33,6 +35,14 @@ void Worker::Run() {
}
Initialize();
worker_node_.set_event_callback([&](const core::NodeEvent &event) {
if ((event == core::NodeEvent::CLUSTER_TIMEOUT) ||
(event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) {
MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!";
Finalize();
exit(0);
}
});
MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
worker_node_.Start();
MS_LOG(INFO) << "Worker connected successfully.";
@ -86,7 +96,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs,
}
MS_LOG(INFO) << "The total size is:" << total_size;
while (!IsReadyForPush(keys[0])) {
while (running_ && (!IsReadyForPush(keys[0]))) {
continue;
}
std::vector<int> sizes_int;
@ -109,7 +119,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs,
void Worker::Pull(const size_t key, void *dev_addr, const size_t size) {
MS_EXCEPTION_IF_NULL(dev_addr);
std::vector<float> variables(size / sizeof(float), 0);
while (!IsReadyForPull(key)) {
while (running_ && (!IsReadyForPull(key))) {
continue;
}
PullData({key}, &variables, nullptr, kPullCmd);
@ -214,7 +224,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &
std::string kv_data = embedding_table_meta.SerializeAsString();
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@ -280,7 +290,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@ -303,7 +313,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
for (auto j = 0; j < message.values_size(); j++) {
values->push_back(message.values(j));
}
MS_LOG(DEBUG) << "The embedding resp:" << values;
MS_LOG(DEBUG) << "The embedding resp:" << *values;
for (auto k = 0; k < message.keys_size(); k++) {
const Key &key = message.keys(k);
float *addr = values->data() + value_offset;
@ -358,7 +368,7 @@ void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vecto
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@ -378,7 +388,7 @@ void Worker::Finalize() {
kvs.add_keys(0);
kvs.add_values(0.0f);
std::string kv_data = kvs.SerializeAsString();
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@ -619,7 +629,7 @@ void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &va
SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {});
} else {
std::string kv_data = kvs.SerializeAsString();
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@ -920,7 +930,7 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@ -945,7 +955,7 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";

Loading…
Cancel
Save