added server node

pull/9843/head
chendongsheng 5 years ago
parent 25171b454a
commit 2d2bf2d0ee

@ -5,6 +5,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc")
list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "util.cc")
list(REMOVE_ITEM _PS_SRC_FILES "embedding_table_shard_metadata.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/http_message_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/http_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/comm_util.cc")
@ -16,6 +17,8 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc")
endif ()
if (NOT ENABLE_D)

@ -0,0 +1,212 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/abstract_node.h"
namespace mindspore {
namespace ps {
namespace core {
// Registers this node with the scheduler.
//
// Packs the node id, role, ip and port from node_info_ into a RegisterMessage, wraps it in a
// CommMessage with the REGISTER command and sends it through `client` with SendMessageSync.
// Raises via MS_LOG(EXCEPTION) when SendMessageSync returns false.
void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
  MessageMeta message_meta;
  message_meta.set_cmd(NodeCommand::REGISTER);

  RegisterMessage register_message;
  register_message.set_node_id(node_info_.node_id_);
  register_message.set_role(node_info_.node_role_);
  register_message.set_ip(node_info_.ip_);
  register_message.set_port(node_info_.port_);

  CommMessage comm_message;
  *comm_message.mutable_pb_meta() = {message_meta};
  comm_message.set_data(register_message.SerializeAsString());
  if (!SendMessageSync(client, comm_message)) {
    MS_LOG(EXCEPTION) << "Node register timeout!";
  }

  // The local `message_meta` was copied into `comm_message` and is never assigned a request id,
  // so read the id from the message that was actually handed to SendMessageSync.
  // NOTE(review): whether SendMessageSync stamps the id onto the passed message is not visible here.
  MS_LOG(INFO) << "The node id:" << node_info_.node_id_
               << " is registering to scheduler, the request id is:" << comm_message.pb_meta().request_id();
}
void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
RegisterRespMessage register_resp_message;
register_resp_message.ParseFromString(message.data());
if (register_resp_message.node_id() != node_info_.node_id_) {
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
<< " is not match the current node id:" << node_info_.node_id_;
}
node_info_.rank_id_ = register_resp_message.rank_id();
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
}
bool AbstractNode::BroadcastToServers(const std::string &message, const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient((*it).first.second);
client->SendMessage(comm_message);
}
return Wait(request_id, timeout);
}
void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_message) {
on_node_event_message_ = on_node_event_message;
}
void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
<< " begin send heartbeat to the scheduler!";
heart_beat_thread_ = std::make_unique<std::thread>([&]() {
while (!is_finish_.load()) {
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
MessageMeta meta;
meta.set_cmd(NodeCommand::HEARTBEAT);
HeartbeatMessage heartbeat_message;
heartbeat_message.set_node_id(node_info_.node_id_);
CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(heartbeat_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
}
}
});
heart_beat_thread_->detach();
}
void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.ParseFromString(message.data());
is_ready_ = heartbeat_resp_message.is_cluster_ready();
if (is_ready_.load()) {
wait_start_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
}
is_finish_ = heartbeat_resp_message.is_cluster_finish();
if (is_finish_.load()) {
wait_finish_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
}
is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
if (is_timeout_ && on_node_event_message_) {
is_ready_ = true;
wait_start_cond_.notify_all();
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
}
}
void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
MessageMeta meta;
meta.set_cmd(NodeCommand::FETCH_SERVER);
CommMessage message;
*message.mutable_pb_meta() = {meta};
if (!SendMessageSync(client, message)) {
MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
}
}
void AbstractNode::ProcessFetchServersResp(const CommMessage &message) {
FetchServersRespMessage fetch_servers_resp_message;
fetch_servers_resp_message.ParseFromString(message.data());
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
}
MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
}
bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
MessageMeta meta;
meta.set_cmd(NodeCommand::FINISH);
FinishMessage finish_message;
finish_message.set_node_id(node_info_.node_id_);
CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(finish_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
MS_LOG(EXCEPTION) << "Disconnect timeout!";
}
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
return WaitForDisconnect(timeout);
}
bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
if (is_finish_.load()) {
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
}
return is_finish_.load();
});
return res;
}
bool AbstractNode::InitClientToScheduler() {
std::string scheduler_host = ClusterConfig::scheduler_host();
uint16_t scheduler_port = ClusterConfig::scheduler_port();
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
case NodeCommand::HEARTBEAT:
ProcessHeartbeatResp(message);
break;
case NodeCommand::REGISTER:
ProcessRegisterResp(message);
break;
case NodeCommand::FETCH_SERVER:
ProcessFetchServersResp(message);
break;
case NodeCommand::FINISH:
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
NotifyMessageArrival(message);
});
client_to_scheduler_->Init();
client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The worker node start a tcp client!";
client_to_scheduler_->Start();
});
client_to_scheduler_thread_->detach();
client_to_scheduler_->set_disconnected_callback([&]() {
std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval()));
client_to_scheduler_->Stop();
client_to_scheduler_->Init();
});
return client_to_scheduler_->WaitConnected();
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
#include <utility>
#include <string>
#include <memory>
#include "ps/core/node.h"
namespace mindspore {
namespace ps {
namespace core {
// Node base class that adds the scheduler-facing protocol on top of Node: registration, heartbeat,
// server discovery and disconnect, plus the TCP client used to reach the scheduler.
// ServerNode derives from this class.
class AbstractNode : public Node {
public:
// All owned pointers start out null; they are populated by Heartbeat() and InitClientToScheduler().
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
~AbstractNode() override = default;
// Sends `message` to every address in nodes_address_ and waits up to `timeout` for the responses.
bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds);
// Installs the callback invoked with NodeEvent::NODE_TIMEOUT when a heartbeat reply reports a cluster timeout.
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
protected:
// Sends this node's id, role, ip and port to the scheduler (REGISTER).
void Register(const std::shared_ptr<TcpClient> &client);
// Validates the REGISTER reply and stores the assigned rank id.
void ProcessRegisterResp(const CommMessage &message);
// Starts the detached thread that periodically sends HEARTBEAT messages.
void Heartbeat(const std::shared_ptr<TcpClient> &client);
// Updates the ready/finish/timeout flags from a HEARTBEAT reply and wakes waiters.
void ProcessHeartbeatResp(const CommMessage &message);
// Requests the server address list from the scheduler (FETCH_SERVER).
void FetchServers(const std::shared_ptr<TcpClient> &client);
// Stores the server addresses from the FETCH_SERVER reply in nodes_address_.
void ProcessFetchServersResp(const CommMessage &message);
// Sends FINISH to the scheduler, then waits up to `timeout` seconds for the finish flag.
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
// Waits up to `timeout` seconds for is_finish_ to become true.
bool WaitForDisconnect(const uint32_t &timeout);
// Creates and starts client_to_scheduler_, returning the result of its WaitConnected().
bool InitClientToScheduler();
// Thread running the heartbeat loop (detached in Heartbeat()).
std::unique_ptr<std::thread> heart_beat_thread_;
// Thread running the scheduler client's event loop (detached in InitClientToScheduler()).
std::unique_ptr<std::thread> client_to_scheduler_thread_;
// TCP client connected to the scheduler.
std::shared_ptr<TcpClient> client_to_scheduler_;
// Optional node-event callback; see set_event_callback().
OnNodeEventMessage on_node_event_message_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_

@ -31,6 +31,8 @@ uint32_t ClusterConfig::heartbeat_interval_ = 3;
uint32_t ClusterConfig::heartbeat_timeout_ = 30;
// Timeout period for cluster preparation is 300 seconds.
uint32_t ClusterConfig::cluster_available_timeout_ = 300;
// The interval the client waits before retrying its connection to the server is 100 ms.
uint32_t ClusterConfig::connect_interval_ = 100;
void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num,
std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) {
@ -69,6 +71,9 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa
cluster_available_timeout_ = cluster_available_timeout;
}
// Returns the configured connect interval (consumed as milliseconds by the reconnect handler).
uint32_t ClusterConfig::connect_interval() {
  return connect_interval_;
}
// Overrides the static connect interval with the supplied value.
void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) {
  connect_interval_ = connect_interval;
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -42,6 +42,8 @@ class ClusterConfig {
static void set_heartbeat_timeout(const uint32_t &heartbeat_timeout);
static uint32_t cluster_available_timeout();
static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout);
static uint32_t connect_interval();
static void set_connect_interval(const uint32_t &connect_interval);
private:
static uint32_t worker_num_;
@ -51,6 +53,7 @@ class ClusterConfig {
static uint16_t scheduler_port_;
static uint32_t heartbeat_timeout_;
static uint32_t cluster_available_timeout_;
static uint32_t connect_interval_;
};
} // namespace core
} // namespace ps

@ -19,78 +19,11 @@
namespace mindspore {
namespace ps {
namespace core {
void Node::Heartbeat(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
<< " begin send heartbeat to the scheduler!";
heart_beat_thread_ = std::make_unique<std::thread>([&]() {
while (!is_finish_.load()) {
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
MessageMeta meta;
meta.set_cmd(NodeCommand::HEARTBEAT);
HeartbeatMessage heartbeat_message;
heartbeat_message.set_node_id(node_info_.node_id_);
CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(heartbeat_message.SerializeAsString());
SendMessageAsync(client, message);
}
});
heart_beat_thread_->detach();
}
void Node::ProcessHeartbeatResp(const CommMessage &message) {
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.ParseFromString(message.data());
is_ready_ = heartbeat_resp_message.is_cluster_ready();
if (is_ready_.load()) {
wait_start_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
}
is_finish_ = heartbeat_resp_message.is_cluster_finish();
if (is_finish_.load()) {
wait_finish_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
}
is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
if (is_timeout_ && on_node_event_message_) {
is_ready_ = true;
wait_start_cond_.notify_all();
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
}
}
void Node::FetchServers(const std::shared_ptr<TcpClient> &client) {
MessageMeta meta;
meta.set_cmd(NodeCommand::FETCH_SERVER);
CommMessage message;
*message.mutable_pb_meta() = {meta};
if (!SendMessageSync(client, message)) {
MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
}
}
void Node::ProcessFetchServersResp(const CommMessage &message) {
FetchServersRespMessage fetch_servers_resp_message;
fetch_servers_resp_message.ParseFromString(message.data());
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
}
MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
}
// Returns a copy of this node's id.
std::string Node::node_id() const {
  return node_info_.node_id_;
}
// Returns the rank id currently stored in node_info_.
uint32_t Node::rank_id() const {
  return node_info_.rank_id_;
}
void Node::set_callback(const OnNodeEventMessage &on_node_event_message) {
on_node_event_message_ = on_node_event_message;
}
// Returns the role stored in node_info_.
NodeRole Node::role() const {
  return node_info_.node_role_;
}
bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
@ -147,6 +80,7 @@ bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids
bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *comm_message_resp, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(comm_message_resp);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
@ -156,7 +90,7 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
comm_message_resp = &res[rank_id];
*comm_message_resp = res[rank_id];
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
@ -164,6 +98,8 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
@ -175,6 +111,7 @@ bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const s
bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(comm_message_resp);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
@ -213,23 +150,6 @@ bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids
return Wait(request_id, timeout);
}
bool Node::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
MessageMeta meta;
meta.set_cmd(NodeCommand::FINISH);
FinishMessage finish_message;
finish_message.set_node_id(node_info_.node_id_);
CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(finish_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
MS_LOG(EXCEPTION) << "Disconnect timeout!";
}
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
return WaitForDisconnect(timeout);
}
bool Node::WaitForStart(const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(wait_start_mutex_);
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
@ -242,17 +162,6 @@ bool Node::WaitForStart(const uint32_t &timeout) {
return res;
}
bool Node::WaitForDisconnect(const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
if (is_finish_.load()) {
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
}
return is_finish_.load();
});
return res;
}
bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
@ -268,15 +177,6 @@ void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const Comm
client->SendMessage(message);
}
void Node::NotifyMessageArrival(const CommMessage &message) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
const MessageMeta &message_meta = message.pb_meta();
uint64_t request_id = message_meta.request_id();
message_tracker_[request_id].second++;
message_tracker_cond_.notify_all();
}
const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) {
std::lock_guard<std::mutex> lock(client_mutex_);
if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
@ -292,6 +192,7 @@ const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id)
switch (message.pb_meta().cmd()) {
case NodeCommand::SEND_DATA:
ProcessSendDataResp(message);
RunMessageCallback(message.pb_meta().request_id());
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
@ -317,13 +218,13 @@ void Node::ProcessSendDataResp(const CommMessage &message) {
res.insert(std::make_pair(rank_id, message));
receive_messages_[request_id] = res;
}
RunMessageCallback(request_id);
}
void Node::RunMessageCallback(const uint64_t &request_id) {
message_callbacks_mutex_.lock();
if (message_tracker_[request_id].first == message_tracker_[request_id].second - 1) {
// When receiving a message's response, Then compare with the desired number of responses,
// If they are equal, then call the callback function
if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) {
auto it = message_callbacks_.find(request_id);
if (it != message_callbacks_.end()) {
message_callbacks_mutex_.unlock();
@ -346,6 +247,15 @@ void Node::set_message_callback(const uint64_t &request_id, const MessageCallbac
std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
message_callbacks_[request_id] = message_callback;
}
void Node::NotifyMessageArrival(const CommMessage &message) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
const MessageMeta &message_meta = message.pb_meta();
uint64_t request_id = message_meta.request_id();
message_tracker_[request_id].second++;
message_tracker_cond_.notify_all();
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -52,8 +52,7 @@ class Node {
is_timeout_(false),
is_already_stopped_(true),
is_already_finished_(false),
next_request_id_(0),
heart_beat_thread_(nullptr) {}
next_request_id_(0) {}
virtual ~Node() = default;
using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
@ -63,9 +62,10 @@ class Node {
virtual bool Stop() = 0;
virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0;
void set_callback(const OnNodeEventMessage &on_node_event_message);
std::string node_id() const;
uint32_t rank_id() const;
NodeRole role() const;
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
@ -73,27 +73,21 @@ class Node {
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *const comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
const uint32_t &timeout = kCommTimeoutInSeconds);
protected:
void Heartbeat(const std::shared_ptr<TcpClient> &client);
void ProcessHeartbeatResp(const CommMessage &message);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessFetchServersResp(const CommMessage &message);
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
bool WaitForStart(const uint32_t &timeout);
bool WaitForDisconnect(const uint32_t &timeout);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
void NotifyMessageArrival(const CommMessage &message);
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
void ProcessSendDataResp(const CommMessage &message);
void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
void NotifyMessageArrival(const CommMessage &message);
NodeInfo node_info_;
std::atomic<bool> is_ready_;
@ -102,9 +96,6 @@ class Node {
std::atomic<bool> is_already_stopped_;
std::atomic<bool> is_already_finished_;
std::atomic_uint64_t next_request_id_;
std::unique_ptr<std::thread> heart_beat_thread_;
OnNodeEventMessage on_node_event_message_;
// <NodeRole,rank_id>-><ip, port>
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
@ -132,5 +123,4 @@ class Node {
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_NODE_H_

@ -0,0 +1,145 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/server_node.h"
namespace mindspore {
namespace ps {
namespace core {
// Shuts the node down if Stop() has not already done so: stops the TCP server and the scheduler
// client, joins their threads when joinable, and marks the node as stopped.
//
// is_already_stopped_ is cleared early in Initialize(), before the scheduler client and its thread
// exist, so a start-up that aborted part-way can reach this point with some pointers still null.
// Every pointer is therefore checked before use.
ServerNode::~ServerNode() {
  MS_LOG(INFO) << "Stop server node!";
  if (is_already_stopped_.load()) {
    return;
  }

  if (server_ != nullptr) {
    server_->Stop();
  }
  if (client_to_scheduler_ != nullptr) {
    client_to_scheduler_->Stop();
    client_to_scheduler_->StopEventBase();
  }
  if (server_thread_ != nullptr && server_thread_->joinable()) {
    server_thread_->join();
  }
  if (client_to_scheduler_thread_ != nullptr && client_to_scheduler_thread_->joinable()) {
    client_to_scheduler_thread_->join();
  }
  is_already_stopped_ = true;
}
// Brings the server node up: initialises the TCP server and scheduler client, registers with the
// scheduler, starts heartbeating and waits up to `timeout` for the cluster to become ready.
//
// Raises via MS_LOG(EXCEPTION) when WaitForStart fails. Returns true otherwise.
bool ServerNode::Start(const uint32_t &timeout) {
  MS_LOG(INFO) << "Start server node!";
  Initialize();
  Register(client_to_scheduler_);
  Heartbeat(client_to_scheduler_);

  if (!WaitForStart(timeout)) {
    // This is the server node; the message previously named the worker node by mistake.
    MS_LOG(EXCEPTION) << "Start server node timeout!";
  }
  MS_LOG(INFO) << "The cluster is ready to use!";

  // If the cluster became ready without timing out, fetch the addresses of all the servers.
  if (!is_timeout_.load()) {
    FetchServers(client_to_scheduler_);
    MS_LOG(INFO) << "Server node get all the servers address successful!";
  }
  MS_LOG(INFO) << "Start the node is successful!";
  return true;
}
// Stores a copy of the handler that ProcessSendData invokes for incoming SEND_DATA requests.
void ServerNode::set_handler(const RequestHandler &handler) {
  request_handler_ = handler;
}
void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
const std::string &message) {
auto &meta = const_cast<MessageMeta &>(message_meta);
meta.set_role(node_info_.node_role_);
meta.set_rank_id(node_info_.rank_id_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {meta};
comm_message.set_data(message);
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
}
void ServerNode::CreateTcpServer() {
std::string interface;
std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0);
server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
case NodeCommand::SEND_DATA:
ProcessSendData(server, conn, message);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
});
server_->Init();
server_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The server node start a tcp server!";
server_->Start();
});
server_thread_->detach();
}
// Prepares the node for start-up: creates the TCP server, fills node_info_ (fresh UUID, SERVER
// role, bound ip/port) and connects the client to the scheduler.
// Raises via MS_LOG(EXCEPTION) when InitClientToScheduler() returns false.
void ServerNode::Initialize() {
  CreateTcpServer();
  is_already_stopped_ = false;

  // Identity of this node as it will be reported to the scheduler.
  node_info_.node_id_ = CommUtil::GenerateUUID();
  node_info_.node_role_ = NodeRole::SERVER;
  node_info_.ip_ = server_->BoundIp();
  node_info_.port_ = server_->BoundPort();
  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " is generate uuid is:" << node_info_.node_id_;

  const bool client_ready = InitClientToScheduler();
  if (!client_ready) {
    MS_LOG(EXCEPTION) << "Server node init client timeout!";
  }
  MS_LOG(INFO) << "Server node init client successful!";
}
// Forwards an incoming SEND_DATA request (its metadata and payload) to the installed request
// handler; does nothing when no handler has been set.
void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
  if (!request_handler_) {
    return;
  }
  request_handler_(server, conn, message.pb_meta(), message.data());
}
// Stops the TCP server and the scheduler client and joins the node's threads when joinable.
// Does nothing beyond logging when the node is already stopped. Always returns true.
//
// is_already_stopped_ is cleared early in Initialize(), while heart_beat_thread_ is only created
// later by Heartbeat(); a Stop() after an aborted start-up can therefore see null pointers, so
// each one is checked before use.
bool ServerNode::Stop() {
  MS_LOG(INFO) << "Stop server node!";
  if (is_already_stopped_.load()) {
    return true;
  }

  if (server_ != nullptr) {
    server_->Stop();
  }
  if (client_to_scheduler_ != nullptr) {
    client_to_scheduler_->Stop();
    client_to_scheduler_->StopEventBase();
  }
  if (server_thread_ != nullptr && server_thread_->joinable()) {
    server_thread_->join();
  }
  if (client_to_scheduler_thread_ != nullptr && client_to_scheduler_thread_->joinable()) {
    client_to_scheduler_thread_->join();
  }
  if (heart_beat_thread_ != nullptr && heart_beat_thread_->joinable()) {
    heart_beat_thread_->join();
  }
  is_already_stopped_ = true;
  return true;
}
// Announces completion to the scheduler exactly once. The first call marks the node finished and
// returns the result of Disconnect(); later calls just log and return true.
bool ServerNode::Finish(const uint32_t &timeout) {
  std::lock_guard<std::mutex> guard(finish_mutex_);
  if (!is_already_finished_) {
    is_already_finished_ = true;
    return Disconnect(client_to_scheduler_, timeout);
  }
  MS_LOG(INFO) << "Server node already finish!";
  return true;
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"
#include "ps/core/abstract_node.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
namespace core {
// Server-role node: owns a TcpServer that receives SEND_DATA requests and hands them to a
// user-supplied handler, on top of the scheduler protocol inherited from AbstractNode.
class ServerNode : public AbstractNode {
public:
// The server and its thread are created later by CreateTcpServer().
ServerNode() : server_(nullptr), server_thread_(nullptr) {}
// Stops the server and scheduler client if the node has not been stopped yet.
~ServerNode() override;
// Initialises, registers, starts heartbeating and waits up to `timeout` for the cluster to be ready.
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
// Stops the server and scheduler client; always returns true.
bool Stop() override;
// Sends FINISH to the scheduler once and waits up to `timeout` for the cluster to finish.
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
// Signature of the request handler called for each SEND_DATA request.
// NOTE(review): message_meta is taken by const value, not by reference — confirm this is intended.
// NOTE(review): std::function is used but <functional> is not included directly by this header.
using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn,
const MessageMeta message_meta, const std::string &message)>;
// Installs the handler invoked by ProcessSendData().
void set_handler(const RequestHandler &handler);
// Sends `message` back over `conn` with this node's role and rank id in the metadata.
void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
const std::string &message);
private:
// Creates, initialises and starts server_ on a detached thread.
void CreateTcpServer();
// Creates the server, fills node_info_ and connects to the scheduler.
void Initialize();
// Forwards a SEND_DATA request to request_handler_ when one is set.
void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
// TCP server accepting requests from other nodes.
std::shared_ptr<TcpServer> server_;
// Thread running server_->Start() (detached in CreateTcpServer()).
std::unique_ptr<std::thread> server_thread_;
// Handler for incoming SEND_DATA requests; may be empty.
RequestHandler request_handler_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_

@ -35,7 +35,6 @@
namespace mindspore {
namespace ps {
namespace core {
event_base *TcpClient::event_base_ = nullptr;
TcpClient::TcpClient(const std::string &address, std::uint16_t port)
@ -43,7 +42,8 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
buffer_event_(nullptr),
server_address_(std::move(address)),
server_port_(port),
is_stop_(true) {
is_stop_(true),
is_connected_(false) {
message_handler_.SetCallback([this](const CommMessage &message) {
if (message_callback_) {
message_callback_(*this, message);
@ -55,12 +55,15 @@ TcpClient::~TcpClient() { Stop(); }
std::string TcpClient::GetServerAddress() const { return server_address_; }
void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
const OnTimeout &timeout) {
connected_callback_ = conn;
disconnected_callback_ = disconn;
read_callback_ = read;
timeout_callback_ = timeout;
// Stores a copy of the callback that EventCallback invokes when the connection reports an error.
void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) {
  disconnected_callback_ = disconnected;
}
// Stores a copy of the callback that EventCallback invokes once the connection is established.
void TcpClient::set_connected_callback(const OnConnected &connected) {
  connected_callback_ = connected;
}
bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
std::unique_lock<std::mutex> lock(connection_mutex_);
bool res =
connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), [&] { return is_connected_.load(); });
return res;
}
void TcpClient::Init() {
@ -68,6 +71,7 @@ void TcpClient::Init() {
if (buffer_event_) {
return;
}
is_stop_ = false;
if (!CommUtil::CheckIp(server_address_)) {
MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
}
@ -198,6 +202,12 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
}
}
void TcpClient::NotifyConnected() {
MS_LOG(INFO) << "Client connected to the server!";
is_connected_ = true;
connection_cond_.notify_all();
}
void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(ptr);
@ -205,27 +215,24 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void
if (events & BEV_EVENT_CONNECTED) {
// Connected
if (tcp_client->connected_callback_) {
tcp_client->connected_callback_(*tcp_client);
tcp_client->connected_callback_();
}
evutil_socket_t fd = bufferevent_getfd(const_cast<struct bufferevent *>(bev));
tcp_client->NotifyConnected();
evutil_socket_t fd = bufferevent_getfd(bev);
SetTcpNoDelay(fd);
MS_LOG(INFO) << "Client connected!";
} else if (events & BEV_EVENT_ERROR) {
MS_LOG(ERROR) << "Client connected error!";
if (tcp_client->disconnected_callback_) {
tcp_client->disconnected_callback_(*tcp_client, errno);
tcp_client->disconnected_callback_();
}
} else if (events & BEV_EVENT_EOF) {
MS_LOG(ERROR) << "Client connected end of file";
if (tcp_client->disconnected_callback_) {
tcp_client->disconnected_callback_(*tcp_client, 0);
}
}
}
void TcpClient::Start() {
MS_EXCEPTION_IF_NULL(event_base_);
is_stop_ = false;
int ret = event_base_dispatch(event_base_);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)

@ -30,19 +30,19 @@
#include <thread>
#include <mutex>
#include <atomic>
#include <condition_variable>
#include "ps/core/cluster_config.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
namespace mindspore {
namespace ps {
namespace core {
class TcpClient {
public:
using OnConnected = std::function<void(const TcpClient &)>;
using OnDisconnected = std::function<void(const TcpClient &, int)>;
using OnConnected = std::function<void()>;
using OnDisconnected = std::function<void()>;
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
using OnTimeout = std::function<void(const TcpClient &)>;
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
@ -52,8 +52,9 @@ class TcpClient {
virtual ~TcpClient();
std::string GetServerAddress() const;
void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
const OnTimeout &timeout);
void set_disconnected_callback(const OnDisconnected &disconnected);
void set_connected_callback(const OnConnected &connected);
bool WaitConnected(const uint32_t &connected_timeout = ClusterConfig::cluster_available_timeout());
void Init();
void StartWithDelay(int seconds);
void Stop();
@ -73,6 +74,7 @@ class TcpClient {
static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
virtual void OnReadHandler(const void *buf, size_t num);
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
void NotifyConnected();
private:
OnMessage message_callback_;
@ -86,12 +88,14 @@ class TcpClient {
static event_base *event_base_;
std::mutex connection_mutex_;
std::condition_variable connection_cond_;
event *event_timeout_;
bufferevent *buffer_event_;
std::string server_address_;
std::uint16_t server_port_;
std::atomic<bool> is_stop_;
std::atomic<bool> is_connected_;
};
} // namespace core

@ -95,6 +95,7 @@ void TcpServer::Init() {
MS_LOG(EXCEPTION) << "Use event pthread failed!";
}
is_stop_ = false;
base_ = event_base_new();
MS_EXCEPTION_IF_NULL(base_);
if (!CommUtil::CheckIp(server_address_)) {
@ -138,7 +139,6 @@ void TcpServer::Start() {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server!";
MS_EXCEPTION_IF_NULL(base_);
is_stop_ = false;
int ret = event_base_dispatch(base_);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
@ -368,6 +368,7 @@ int TcpServer::ConnectionNum() const { return connections_.size(); }
// Gives read-only access to connections_, the map from socket descriptor to connection object.
const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const {
  return connections_;
}
// Replaces message_callback_ with a copy of the given handler.
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) {
  message_callback_ = cb;
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -31,8 +31,8 @@ WorkerNode::~WorkerNode() {
}
}
client_to_scheduler_->StopEventBase();
if (worker_thread_->joinable()) {
worker_thread_->join();
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
@ -43,7 +43,7 @@ WorkerNode::~WorkerNode() {
bool WorkerNode::Start(const uint32_t &timeout) {
MS_LOG(INFO) << "Starting worker node!";
Initialize();
Register();
Register(client_to_scheduler_);
Heartbeat(client_to_scheduler_);
if (!WaitForStart(timeout)) {
@ -52,84 +52,25 @@ bool WorkerNode::Start(const uint32_t &timeout) {
}
MS_LOG(INFO) << "The node is ready to fetch servers!";
// If the cluster is ready to use, then Get the address of all the servers
if (!is_timeout_.load()) {
FetchServers(client_to_scheduler_);
MS_LOG(INFO) << "Fetch servers successful!";
MS_LOG(INFO) << "Worker node get all the servers address successful!";
}
MS_LOG(INFO) << "The Worker node has successfully started.";
return true;
}
// Builds a REGISTER request carrying this node's id and role and sends it to the scheduler
// through SendMessageSync. Raises through MS_LOG(EXCEPTION) when SendMessageSync returns false.
void WorkerNode::Register() {
  MessageMeta message_meta;
  message_meta.set_cmd(NodeCommand::REGISTER);

  RegisterMessage register_message;
  register_message.set_node_id(node_info_.node_id_);
  register_message.set_role(node_info_.node_role_);

  CommMessage comm_message;
  *comm_message.mutable_pb_meta() = {message_meta};
  comm_message.set_data(register_message.SerializeAsString());
  if (!SendMessageSync(client_to_scheduler_, comm_message)) {
    MS_LOG(EXCEPTION) << "Worker node register timeout!";
  }

  // NOTE(review): message_meta was copied into comm_message before sending, and no request id is
  // set on this local copy in this function - confirm that the value logged here is the intended one.
  MS_LOG(INFO) << "The worker node id:" << node_info_.node_id_
               << " is registering to scheduler, the request id is:" << message_meta.request_id();
}
// Handles the scheduler's reply to a REGISTER request: decodes the payload, checks that it is
// addressed to this node, and stores the rank id it carries in node_info_.rank_id_.
// Raises through MS_LOG(EXCEPTION) when the payload cannot be decoded or the node id differs.
void WorkerNode::ProcessRegisterResp(const CommMessage &message) {
  RegisterRespMessage register_resp_message;
  // ParseFromString reports malformed input through its return value; fail here with a precise
  // message instead of continuing with a partially filled response.
  if (!register_resp_message.ParseFromString(message.data())) {
    MS_LOG(EXCEPTION) << "Failed to parse the register response message!";
  }
  if (register_resp_message.node_id() != node_info_.node_id_) {
    MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
                      << " is not match the current node id:" << node_info_.node_id_;
  }
  node_info_.rank_id_ = register_resp_message.rank_id();

  MS_LOG(INFO) << "The client node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
}
void WorkerNode::Initialize() {
is_already_stopped_ = false;
node_info_.node_id_ = CommUtil::GenerateUUID();
node_info_.node_role_ = NodeRole::WORKER;
MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_;
InitClientToScheduler();
}
void WorkerNode::InitClientToScheduler() {
std::string scheduler_host = ClusterConfig::scheduler_host();
uint16_t scheduler_port = ClusterConfig::scheduler_port();
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
case NodeCommand::HEARTBEAT:
ProcessHeartbeatResp(message);
break;
case NodeCommand::REGISTER:
ProcessRegisterResp(message);
break;
case NodeCommand::FETCH_SERVER:
ProcessFetchServersResp(message);
break;
case NodeCommand::FINISH:
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
NotifyMessageArrival(message);
});
client_to_scheduler_->Init();
worker_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The worker node start a tcp client!";
client_to_scheduler_->Start();
});
worker_thread_->detach();
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Worker node init client timeout!";
}
MS_LOG(INFO) << "Worker node init client successful!";
}
bool WorkerNode::Stop() {
@ -144,8 +85,8 @@ bool WorkerNode::Stop() {
}
}
client_to_scheduler_->StopEventBase();
if (worker_thread_->joinable()) {
worker_thread_->join();
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
@ -165,23 +106,6 @@ bool WorkerNode::Finish(const uint32_t &timeout) {
is_already_finished_ = true;
return Disconnect(client_to_scheduler_, timeout);
}
bool WorkerNode::BroadcastToServers(const std::string &message) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient((*it).first.second);
client->SendMessage(comm_message);
}
return Wait(request_id);
}
} // namespace core
} // namespace ps
} // namespace mindspore

@ -17,50 +17,35 @@
#ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
#include <atomic>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <thread>
#include <unordered_map>
#include <utility>
#include <condition_variable>
#include <algorithm>
#include <tuple>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"
#include "ps/core/node.h"
#include "ps/core/abstract_node.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
namespace core {
class WorkerNode : public Node {
class WorkerNode : public AbstractNode {
public:
WorkerNode() : client_to_scheduler_(nullptr), worker_thread_(nullptr) {}
WorkerNode() = default;
~WorkerNode() override;
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
bool BroadcastToServers(const std::string &message);
private:
void Register();
void ProcessRegisterResp(const CommMessage &message);
void Initialize();
void InitClientToScheduler();
std::shared_ptr<TcpClient> client_to_scheduler_;
std::unique_ptr<std::thread> worker_thread_;
};
} // namespace core
} // namespace ps

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/embedding_table_shard_metadata.h"
namespace mindspore {
namespace ps {
// Returns the stored begin_ bound.
uint64_t EmbeddingTableShardMetadata::begin() const {
  return begin_;
}
// Returns the stored end_ bound.
uint64_t EmbeddingTableShardMetadata::end() const {
  return end_;
}
// Returns end_ - begin_ using unsigned 64-bit arithmetic (no check that end_ >= begin_).
uint64_t EmbeddingTableShardMetadata::size() const {
  const uint64_t span = end_ - begin_;
  return span;
}
} // namespace ps
} // namespace mindspore

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_
#define MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_
#include <iostream>
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
// Holds the two bounds of one embedding table shard.
// NOTE(review): the bounds look like a half-open [begin, end) range - confirm against callers.
class EmbeddingTableShardMetadata {
 public:
  // Stores both bounds exactly as given; no validation is done here.
  explicit EmbeddingTableShardMetadata(uint64_t begin, uint64_t end) : begin_{begin}, end_{end} {}
  virtual ~EmbeddingTableShardMetadata() = default;

  // Accessors; declared here and defined out of line.
  uint64_t begin() const;
  uint64_t end() const;
  uint64_t size() const;

 private:
  uint64_t begin_;  // first constructor argument
  uint64_t end_;    // second constructor argument
};
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "ps/embedding_table_shard_metadata.h"
namespace mindspore {
namespace ps {
class TestEmbeddingTableShardMetadata : public UT::Common {
public:
TestEmbeddingTableShardMetadata() = default;
virtual ~TestEmbeddingTableShardMetadata() = default;
void SetUp() override {}
void TearDown() override {}
};
// Checks that begin()/end() return the constructor arguments and that size() is their difference.
TEST_F(TestEmbeddingTableShardMetadata, EmbeddingTable) {
  // Expected values are typed as uint64_t, matching the accessors' return type, so EXPECT_EQ
  // does not compare an unsigned result against a signed int literal.
  const uint64_t expected_begin = 1;
  const uint64_t expected_end = 100;
  const uint64_t expected_size = 99;

  EmbeddingTableShardMetadata embedding_table_shard(expected_begin, expected_end);
  EXPECT_EQ(embedding_table_shard.begin(), expected_begin);
  EXPECT_EQ(embedding_table_shard.end(), expected_end);
  EXPECT_EQ(embedding_table_shard.size(), expected_size);
}
} // namespace ps
} // namespace mindspore
Loading…
Cancel
Save