parent
25171b454a
commit
2d2bf2d0ee
@ -0,0 +1,212 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/abstract_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::REGISTER);
|
||||
|
||||
RegisterMessage register_message;
|
||||
register_message.set_node_id(node_info_.node_id_);
|
||||
register_message.set_role(node_info_.node_role_);
|
||||
register_message.set_ip(node_info_.ip_);
|
||||
register_message.set_port(node_info_.port_);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(register_message.SerializeAsString());
|
||||
if (!SendMessageSync(client, comm_message)) {
|
||||
MS_LOG(EXCEPTION) << "Node register timeout!";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The node id:" << node_info_.node_id_
|
||||
<< "is registering to scheduler, the request id is:" << message_meta.request_id();
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
|
||||
RegisterRespMessage register_resp_message;
|
||||
register_resp_message.ParseFromString(message.data());
|
||||
if (register_resp_message.node_id() != node_info_.node_id_) {
|
||||
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
|
||||
<< " is not match the current node id:" << node_info_.node_id_;
|
||||
}
|
||||
|
||||
node_info_.rank_id_ = register_resp_message.rank_id();
|
||||
|
||||
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
|
||||
}
|
||||
|
||||
bool AbstractNode::BroadcastToServers(const std::string &message, const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
|
||||
for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(message);
|
||||
auto client = GetOrCreateTcpClient((*it).first.second);
|
||||
client->SendMessage(comm_message);
|
||||
}
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_message) {
|
||||
on_node_event_message_ = on_node_event_message;
|
||||
}
|
||||
|
||||
void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
|
||||
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
|
||||
<< " begin send heartbeat to the scheduler!";
|
||||
heart_beat_thread_ = std::make_unique<std::thread>([&]() {
|
||||
while (!is_finish_.load()) {
|
||||
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::HEARTBEAT);
|
||||
|
||||
HeartbeatMessage heartbeat_message;
|
||||
heartbeat_message.set_node_id(node_info_.node_id_);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
message.set_data(heartbeat_message.SerializeAsString());
|
||||
if (!SendMessageSync(client, message)) {
|
||||
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
|
||||
}
|
||||
}
|
||||
});
|
||||
heart_beat_thread_->detach();
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
|
||||
HeartbeatRespMessage heartbeat_resp_message;
|
||||
heartbeat_resp_message.ParseFromString(message.data());
|
||||
is_ready_ = heartbeat_resp_message.is_cluster_ready();
|
||||
if (is_ready_.load()) {
|
||||
wait_start_cond_.notify_all();
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
|
||||
}
|
||||
is_finish_ = heartbeat_resp_message.is_cluster_finish();
|
||||
if (is_finish_.load()) {
|
||||
wait_finish_cond_.notify_all();
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
|
||||
}
|
||||
is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
|
||||
if (is_timeout_ && on_node_event_message_) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
|
||||
}
|
||||
}
|
||||
|
||||
void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::FETCH_SERVER);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
if (!SendMessageSync(client, message)) {
|
||||
MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
|
||||
}
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessFetchServersResp(const CommMessage &message) {
|
||||
FetchServersRespMessage fetch_servers_resp_message;
|
||||
fetch_servers_resp_message.ParseFromString(message.data());
|
||||
|
||||
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
|
||||
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
|
||||
}
|
||||
|
||||
bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::FINISH);
|
||||
|
||||
FinishMessage finish_message;
|
||||
finish_message.set_node_id(node_info_.node_id_);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
message.set_data(finish_message.SerializeAsString());
|
||||
if (!SendMessageSync(client, message)) {
|
||||
MS_LOG(EXCEPTION) << "Disconnect timeout!";
|
||||
}
|
||||
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
|
||||
return WaitForDisconnect(timeout);
|
||||
}
|
||||
|
||||
bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
|
||||
bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
if (is_finish_.load()) {
|
||||
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
|
||||
}
|
||||
return is_finish_.load();
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
bool AbstractNode::InitClientToScheduler() {
|
||||
std::string scheduler_host = ClusterConfig::scheduler_host();
|
||||
uint16_t scheduler_port = ClusterConfig::scheduler_port();
|
||||
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
|
||||
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
|
||||
switch (message.pb_meta().cmd()) {
|
||||
case NodeCommand::HEARTBEAT:
|
||||
ProcessHeartbeatResp(message);
|
||||
break;
|
||||
case NodeCommand::REGISTER:
|
||||
ProcessRegisterResp(message);
|
||||
break;
|
||||
case NodeCommand::FETCH_SERVER:
|
||||
ProcessFetchServersResp(message);
|
||||
break;
|
||||
case NodeCommand::FINISH:
|
||||
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
||||
}
|
||||
NotifyMessageArrival(message);
|
||||
});
|
||||
|
||||
client_to_scheduler_->Init();
|
||||
client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
|
||||
MS_LOG(INFO) << "The worker node start a tcp client!";
|
||||
client_to_scheduler_->Start();
|
||||
});
|
||||
client_to_scheduler_thread_->detach();
|
||||
|
||||
client_to_scheduler_->set_disconnected_callback([&]() {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval()));
|
||||
client_to_scheduler_->Stop();
|
||||
client_to_scheduler_->Init();
|
||||
});
|
||||
return client_to_scheduler_->WaitConnected();
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
@ -0,0 +1,56 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
|
||||
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ps/core/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class AbstractNode : public Node {
|
||||
public:
|
||||
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
|
||||
~AbstractNode() override = default;
|
||||
|
||||
bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
|
||||
|
||||
protected:
|
||||
void Register(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessRegisterResp(const CommMessage &message);
|
||||
void Heartbeat(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessHeartbeatResp(const CommMessage &message);
|
||||
void FetchServers(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessFetchServersResp(const CommMessage &message);
|
||||
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
|
||||
bool WaitForDisconnect(const uint32_t &timeout);
|
||||
bool InitClientToScheduler();
|
||||
|
||||
std::unique_ptr<std::thread> heart_beat_thread_;
|
||||
std::unique_ptr<std::thread> client_to_scheduler_thread_;
|
||||
std::shared_ptr<TcpClient> client_to_scheduler_;
|
||||
OnNodeEventMessage on_node_event_message_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
|
@ -0,0 +1,145 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ps/core/server_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
// Stops the TCP server and the scheduler client and joins their threads (when joinable),
// unless the node has already been marked as stopped.
ServerNode::~ServerNode() {
  MS_LOG(INFO) << "Stop server node!";
  if (is_already_stopped_.load()) {
    return;
  }

  server_->Stop();
  client_to_scheduler_->Stop();
  client_to_scheduler_->StopEventBase();

  if (server_thread_->joinable()) {
    server_thread_->join();
  }
  if (client_to_scheduler_thread_->joinable()) {
    client_to_scheduler_thread_->join();
  }

  is_already_stopped_ = true;
}
|
||||
|
||||
bool ServerNode::Start(const uint32_t &timeout) {
|
||||
MS_LOG(INFO) << "Start server node!";
|
||||
Initialize();
|
||||
Register(client_to_scheduler_);
|
||||
Heartbeat(client_to_scheduler_);
|
||||
|
||||
if (!WaitForStart(timeout)) {
|
||||
MS_LOG(EXCEPTION) << "Start Worker node timeout!";
|
||||
}
|
||||
MS_LOG(INFO) << "The cluster is ready to use!";
|
||||
|
||||
// If the cluster is ready to use, then Get the address of all the servers
|
||||
if (!is_timeout_.load()) {
|
||||
FetchServers(client_to_scheduler_);
|
||||
MS_LOG(INFO) << "Server node get all the servers address successful!";
|
||||
}
|
||||
MS_LOG(INFO) << "Start the node is successful!";
|
||||
return true;
|
||||
}
|
||||
|
||||
void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
|
||||
|
||||
void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
|
||||
const std::string &message) {
|
||||
auto &meta = const_cast<MessageMeta &>(message_meta);
|
||||
meta.set_role(node_info_.node_role_);
|
||||
meta.set_rank_id(node_info_.rank_id_);
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {meta};
|
||||
comm_message.set_data(message);
|
||||
|
||||
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
|
||||
}
|
||||
|
||||
void ServerNode::CreateTcpServer() {
|
||||
std::string interface;
|
||||
std::string server_ip;
|
||||
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
|
||||
server_ = std::make_shared<TcpServer>(server_ip, 0);
|
||||
server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
|
||||
switch (message.pb_meta().cmd()) {
|
||||
case NodeCommand::SEND_DATA:
|
||||
ProcessSendData(server, conn, message);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
||||
}
|
||||
});
|
||||
server_->Init();
|
||||
server_thread_ = std::make_unique<std::thread>([&]() {
|
||||
MS_LOG(INFO) << "The server node start a tcp server!";
|
||||
server_->Start();
|
||||
});
|
||||
server_thread_->detach();
|
||||
}
|
||||
|
||||
void ServerNode::Initialize() {
|
||||
CreateTcpServer();
|
||||
is_already_stopped_ = false;
|
||||
node_info_.node_id_ = CommUtil::GenerateUUID();
|
||||
node_info_.node_role_ = NodeRole::SERVER;
|
||||
node_info_.ip_ = server_->BoundIp();
|
||||
node_info_.port_ = server_->BoundPort();
|
||||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " is generate uuid is:" << node_info_.node_id_;
|
||||
if (!InitClientToScheduler()) {
|
||||
MS_LOG(EXCEPTION) << "Server node init client timeout!";
|
||||
}
|
||||
MS_LOG(INFO) << "Server node init client successful!";
|
||||
}
|
||||
|
||||
void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
|
||||
if (request_handler_) {
|
||||
request_handler_(server, conn, message.pb_meta(), message.data());
|
||||
}
|
||||
}
|
||||
|
||||
bool ServerNode::Stop() {
|
||||
MS_LOG(INFO) << "Stop server node!";
|
||||
if (!is_already_stopped_.load()) {
|
||||
server_->Stop();
|
||||
client_to_scheduler_->Stop();
|
||||
client_to_scheduler_->StopEventBase();
|
||||
if (server_thread_->joinable()) {
|
||||
server_thread_->join();
|
||||
}
|
||||
if (client_to_scheduler_thread_->joinable()) {
|
||||
client_to_scheduler_thread_->join();
|
||||
}
|
||||
if (heart_beat_thread_->joinable()) {
|
||||
heart_beat_thread_->join();
|
||||
}
|
||||
is_already_stopped_ = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ServerNode::Finish(const uint32_t &timeout) {
|
||||
std::lock_guard<std::mutex> lock(finish_mutex_);
|
||||
if (is_already_finished_) {
|
||||
MS_LOG(INFO) << "Server node already finish!";
|
||||
return true;
|
||||
}
|
||||
is_already_finished_ = true;
|
||||
return Disconnect(client_to_scheduler_, timeout);
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
@ -0,0 +1,67 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/core/tcp_client.h"
|
||||
#include "ps/core/tcp_server.h"
|
||||
#include "ps/core/abstract_node.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class ServerNode : public AbstractNode {
|
||||
public:
|
||||
ServerNode() : server_(nullptr), server_thread_(nullptr) {}
|
||||
~ServerNode() override;
|
||||
|
||||
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
bool Stop() override;
|
||||
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
|
||||
using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn,
|
||||
const MessageMeta message_meta, const std::string &message)>;
|
||||
|
||||
void set_handler(const RequestHandler &handler);
|
||||
void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
|
||||
const std::string &message);
|
||||
|
||||
private:
|
||||
void CreateTcpServer();
|
||||
void Initialize();
|
||||
void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
|
||||
|
||||
std::shared_ptr<TcpServer> server_;
|
||||
std::unique_ptr<std::thread> server_thread_;
|
||||
RequestHandler request_handler_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_
|
@ -0,0 +1,27 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/embedding_table_shard_metadata.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
uint64_t EmbeddingTableShardMetadata::begin() const { return begin_; }
|
||||
|
||||
uint64_t EmbeddingTableShardMetadata::end() const { return end_; }
|
||||
|
||||
// Returns the distance `end - begin` of the shard.
//
// The constructor does not validate its arguments, so `end_` may be smaller than `begin_`.
// In that case 0 is returned instead of letting the unsigned subtraction wrap around to a
// huge value.
uint64_t EmbeddingTableShardMetadata::size() const {
  return end_ >= begin_ ? end_ - begin_ : 0;
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_
|
||||
#define MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_
|
||||
|
||||
#include <iostream>
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
// Describes one shard of an embedding table by a pair of unsigned 64-bit bounds.
class EmbeddingTableShardMetadata {
 public:
  // Stores the two bounds as given; no validation is performed here.
  explicit EmbeddingTableShardMetadata(uint64_t begin, uint64_t end) : begin_{begin}, end_{end} {}
  virtual ~EmbeddingTableShardMetadata() = default;

  // Accessor for the first bound.
  uint64_t begin() const;
  // Accessor for the second bound.
  uint64_t end() const;
  // Size of the shard derived from the two bounds.
  uint64_t size() const;

 private:
  uint64_t begin_;
  uint64_t end_;
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_EMBEDDING_TABLE_SHARD_METADATA_H_
|
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "ps/embedding_table_shard_metadata.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
class TestEmbeddingTableShardMetadata : public UT::Common {
|
||||
public:
|
||||
TestEmbeddingTableShardMetadata() = default;
|
||||
virtual ~TestEmbeddingTableShardMetadata() = default;
|
||||
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
// A shard built from (1, 100) reports those bounds and a size of 99.
TEST_F(TestEmbeddingTableShardMetadata, EmbeddingTable) {
  EmbeddingTableShardMetadata shard(1, 100);

  EXPECT_EQ(shard.begin(), 1);
  EXPECT_EQ(shard.end(), 100);
  EXPECT_EQ(shard.size(), 99);
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
Loading…
Reference in new issue