added retry and unit test

pull/10689/head
chendongsheng 5 years ago
parent 3ff953bf1d
commit b289c6184a

@ -33,11 +33,12 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(register_message.SerializeAsString());
if (!SendMessageSync(client, comm_message)) {
MS_LOG(EXCEPTION) << "Node register timeout!";
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
}
MS_LOG(INFO) << "The node id:" << node_info_.node_id_
<< "is registering to scheduler, the request id is:" << message_meta.request_id();
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is registering to scheduler!";
}
void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
@ -395,7 +396,7 @@ bool AbstractNode::InitClientToScheduler() {
client_to_scheduler_->Init();
client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The worker node start a tcp client!";
MS_LOG(INFO) << "The node start a tcp client!";
client_to_scheduler_->Start();
});

@ -129,6 +129,16 @@ bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &ra
}
return true;
}
// Calls `func` repeatedly until it returns true or `max_attempts` calls have been made.
//
// @param func                  Operation to attempt; a true result means success and stops the loop.
// @param max_attempts          Upper bound on the number of calls to `func`; 0 means `func` is never called.
// @param interval_milliseconds Pause between two consecutive attempts, in milliseconds.
// @return true as soon as one attempt succeeds, false when every attempt failed.
bool CommUtil::Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds) {
  for (size_t attempt = 0; attempt < max_attempts; ++attempt) {
    if (func()) {
      return true;
    }
    // Only wait when another attempt will follow: sleeping after the final failure
    // would merely delay reporting the failure to the caller.
    if (attempt + 1 < max_attempts) {
      std::this_thread::sleep_for(std::chrono::milliseconds(interval_milliseconds));
    }
  }
  return false;
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -45,6 +45,7 @@
#include <sstream>
#include <string>
#include <utility>
#include <thread>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
@ -68,6 +69,7 @@ class CommUtil {
static std::string GenerateUUID();
static std::string NodeRoleToString(const NodeRole &role);
static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id);
static bool Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds);
private:
static std::random_device rd;

@ -57,7 +57,7 @@ class Node {
using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
using MessageCallback = std::function<void()>;
virtual bool Start(const uint32_t &timeout = kTimeoutInSeconds) = 0;
virtual bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) = 0;
virtual bool Stop() = 0;
virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0;

@ -105,7 +105,7 @@ void NodeManager::UpdateClusterState() {
}
// 2. update cluster finish state
if (finish_nodes_id_.size() == total_node_num_) {
if (finish_nodes_id_.size() == total_node_num_ || SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
is_cluster_finish_ = true;
is_cluster_ready_ = true;
}
@ -119,7 +119,9 @@ void NodeManager::UpdateClusterState() {
void NodeManager::CheckClusterTimeout() {
if (total_node_num_ != nodes_info_.size()) {
MS_LOG(WARNING) << "The cluster is not ready after " << ClusterConfig::cluster_available_timeout()
<< " seconds,so finish the cluster";
<< " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
<< nodes_info_.size();
current_node_num_ = nodes_info_.size();
is_cluster_timeout_ = true;
}
}

@ -35,6 +35,7 @@
#include "proto/ps.pb.h"
#include "ps/core/node.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
namespace mindspore {
namespace ps {
@ -47,6 +48,7 @@ class NodeManager {
is_cluster_timeout_(false),
is_node_timeout_(false),
total_node_num_(0),
current_node_num_(-1),
next_worker_rank_id_(-1),
next_server_rank_id_(-1) {}
virtual ~NodeManager() = default;
@ -75,6 +77,7 @@ class NodeManager {
std::atomic<bool> is_cluster_timeout_;
std::atomic<bool> is_node_timeout_;
uint32_t total_node_num_;
int32_t current_node_num_;
std::atomic<int> next_worker_rank_id_;
std::atomic<int> next_server_rank_id_;
// worker nodes and server nodes

@ -44,7 +44,7 @@ class SchedulerNode : public Node {
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
~SchedulerNode() override;
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;

@ -30,7 +30,7 @@ bool ServerNode::Start(const uint32_t &timeout) {
StartHeartbeatTimer(client_to_scheduler_);
if (!WaitForStart(timeout)) {
MS_LOG(EXCEPTION) << "Start Worker node timeout!";
MS_LOG(ERROR) << "Start Server node timeout!";
}
MS_LOG(INFO) << "The cluster is ready to use!";

@ -40,7 +40,7 @@ class ServerNode : public AbstractNode {
ServerNode() : server_(nullptr), server_thread_(nullptr) {}
~ServerNode() override;
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;

@ -36,6 +36,8 @@ namespace mindspore {
namespace ps {
namespace core {
event_base *TcpClient::event_base_ = nullptr;
std::mutex TcpClient::event_base_mutex_;
bool TcpClient::is_started_ = false;
TcpClient::TcpClient(const std::string &address, std::uint16_t port)
: event_timeout_(nullptr),
@ -60,10 +62,6 @@ TcpClient::~TcpClient() {
event_free(event_timeout_);
event_timeout_ = nullptr;
}
if (event_base_) {
event_base_free(event_base_);
event_base_ = nullptr;
}
}
std::string TcpClient::GetServerAddress() const { return server_address_; }
@ -234,6 +232,13 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void
}
void TcpClient::Start() {
event_base_mutex_.lock();
if (is_started_) {
event_base_mutex_.unlock();
return;
}
is_started_ = true;
event_base_mutex_.unlock();
MS_EXCEPTION_IF_NULL(event_base_);
int ret = event_base_dispatch(event_base_);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
@ -260,7 +265,7 @@ void TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
}

@ -35,6 +35,7 @@
#include "ps/core/cluster_config.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "utils/convert_utils_base.h"
namespace mindspore {
namespace ps {
@ -86,6 +87,9 @@ class TcpClient {
OnTimer on_timer_callback_;
static event_base *event_base_;
static std::mutex event_base_mutex_;
static bool is_started_;
std::mutex connection_mutex_;
std::condition_variable connection_cond_;
event *event_timeout_;

@ -32,7 +32,6 @@
namespace mindspore {
namespace ps {
namespace core {
void TcpConnection::InitConnection() {
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
@ -58,7 +57,7 @@ void TcpConnection::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,
sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
@ -304,7 +303,7 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
if (read == -1) {
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
}
conn->OnReadHandler(read_buffer, static_cast<size_t>(read));
conn->OnReadHandler(read_buffer, IntToSize(read));
}
}

@ -39,6 +39,7 @@
#include "ps/core/tcp_message_handler.h"
#include "ps/core/cluster_config.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
namespace mindspore {
namespace ps {

@ -40,7 +40,7 @@ class WorkerNode : public AbstractNode {
WorkerNode() = default;
~WorkerNode() override;
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "ps/core/node.h"
#include "ps/core/scheduler_node.h"
namespace mindspore {
namespace ps {
namespace core {
// Fixture for the cluster-available-timeout test case. It derives from UT::Common
// and adds neither state nor per-test setup/cleanup of its own.
class TestClusterAvailableTimeout : public UT::Common {
 public:
  TestClusterAvailableTimeout() = default;
  ~TestClusterAvailableTimeout() override = default;

  // Intentionally empty: nothing to prepare before a test runs.
  void SetUp() override {
  }

  // Intentionally empty: nothing to release after a test runs.
  void TearDown() override {
  }
};
// Drives a SchedulerNode through Start/Finish/Stop after shortening the cluster
// available timeout. The case makes no assertions; it only exercises the calls.
TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) {
  // NOTE(review): the two leading 1s are presumably the worker and server counts — confirm against ClusterConfig::Init.
  ClusterConfig::Init(1, 1, std::make_unique<std::string>("127.0.0.1"), 9999);
  // Start() defaults its timeout to ClusterConfig::cluster_available_timeout(), so set it before starting.
  ClusterConfig::set_cluster_available_timeout(3);

  SchedulerNode scheduler;
  scheduler.Start();
  scheduler.Finish();
  scheduler.Stop();
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -27,6 +27,18 @@ class TestCommUtil : public UT::Common {
public:
TestCommUtil() = default;
virtual ~TestCommUtil() = default;
// Callable test double: reports failure for its first three invocations and
// success from the fourth invocation onward. The string argument is ignored.
struct MockRetry {
  bool operator()(std::string mock) {
    // Count this call first, then succeed once more than three calls have been made.
    return ++count_ > 3;
  }

  // Number of times the functor has been invoked so far.
  int count_{0};
};
void SetUp() override {}
void TearDown() override {}
@ -47,6 +59,14 @@ TEST_F(TestCommUtil, ValidateRankId) {
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));
EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2));
}
// Checks both outcomes of CommUtil::Retry: giving up after the attempt budget is
// spent, and succeeding once the callable starts returning true.
TEST_F(TestCommUtil, Retry) {
  // A callable that never succeeds must make Retry report failure.
  const auto always_fail = []() -> bool { return false; };
  EXPECT_FALSE(CommUtil::Retry(always_fail, 5, 100));

  // MockRetry returns true from its fourth call onward, so five attempts suffice.
  MockRetry flaky;
  const auto call_flaky = [&] { return flaky("mock"); };
  EXPECT_TRUE(CommUtil::Retry(call_flaky, 5, 100));
}
} // namespace core
} // namespace ps
} // namespace mindspore
Loading…
Cancel
Save