pull/5930/head
Lixia Chen 5 years ago
parent 88ded11f59
commit 983827ec5c

@ -36,6 +36,7 @@ include(CPack)
set(INSTALL_LIB_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Installation directory for libraries")
set(INSTALL_PY_DIR ".")
set(INSTALL_BASE_DIR ".")
set(INSTALL_BIN_DIR "bin")
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
set(INSTALL_LIB_DIR ".")
@ -78,7 +79,14 @@ if (ENABLE_MINDDATA)
DESTINATION ${INSTALL_BASE_DIR}
COMPONENT mindspore
)
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
install(
TARGETS cache_admin cache_server
OPTIONAL
DESTINATION ${INSTALL_BIN_DIR}
COMPONENT mindspore
)
endif()
file(GLOB_RECURSE OPENCV_LIB_LIST
${opencv_LIBPATH}/libopencv_core*
${opencv_LIBPATH}/libopencv_imgcodecs*

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <optional>
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/cache/cache_client.h"
@ -22,17 +23,19 @@ namespace dataset {
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def(
py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname,
std::optional<int32_t> port, int32_t prefetch_sz) {
std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder;
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz);
if (hostname) builder.SetHostname(hostname.value());
if (port) builder.SetPort(port.value());
THROW_IF_ERROR(builder.Build(&cc));
return cc;
}))
.def(py::init([](session_id_type id, uint64_t mem_sz, bool spill,
std::optional<std::string> hostname, std::optional<int32_t> port,
std::optional<int32_t> num_connections, std::optional<int32_t> prefetch_sz) {
std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder;
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill);
if (hostname) builder.SetHostname(hostname.value());
if (port) builder.SetPort(port.value());
if (num_connections) builder.SetNumConnections(num_connections.value());
if (prefetch_sz) builder.SetPrefetchSize(prefetch_sz.value());
THROW_IF_ERROR(builder.Build(&cc));
return cc;
}))
.def("GetStat", [](CacheClient &cc) {
CacheServiceStat stat{};
THROW_IF_ERROR(cc.GetStat(&stat));
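
For reference (not part of this change), a minimal C++ sketch of the builder path the binding above exercises. It assumes the minddata headers and a reachable cache server; MakeClient and the literal values are illustrative only, and any field left unset falls back to the ConfigManager defaults.
#include <memory>
#include "minddata/dataset/engine/cache/cache_client.h"
namespace ds = mindspore::dataset;
ds::Status MakeClient(std::shared_ptr<ds::CacheClient> *out) {
  ds::CacheClient::Builder builder;
  builder.SetSessionId(1)   // must be positive, see Builder::SanityCheck()
    .SetCacheMemSz(0)       // 0 means unlimited
    .SetSpill(true);
  // The two optional knobs newly exposed to Python map onto these setters.
  builder.SetNumConnections(12).SetPrefetchSize(20);
  return builder.Build(out);
}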

@ -18,6 +18,7 @@
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include "mindspore/core/utils/log_adapter.h"
#include "minddata/dataset/util/system_pool.h"
@ -33,7 +34,9 @@ ConfigManager::ConfigManager()
monitor_sampling_interval_(kCfgMonitorSamplingInterval),
callback_timout_(kCfgCallbackTimeout),
cache_host_(kCfgDefaultCacheHost),
cache_port_(kCfgDefaultCachePort) {
cache_port_(kCfgDefaultCachePort),
num_connections_(kDftNumConnections),
prefetch_size_(kDftPrefetchSize) {
auto env_cache_host = std::getenv("MS_CACHE_HOST");
auto env_cache_port = std::getenv("MS_CACHE_PORT");
if (env_cache_host != nullptr) {
@ -71,6 +74,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
set_cache_host(j.value("cacheHost", cache_host_));
set_cache_port(j.value("cachePort", cache_port_));
set_num_connections(j.value("numConnections", num_connections_));
set_prefetch_size(j.value("prefetchSize", prefetch_size_));
return Status::OK();
}
@ -120,8 +125,12 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s
void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; }
void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; }
void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = std::move(cache_host); }
void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; }
void ConfigManager::set_num_connections(int32_t num_connections) { num_connections_ = num_connections; }
void ConfigManager::set_prefetch_size(int32_t prefetch_size) { prefetch_size_ = prefetch_size; }
} // namespace dataset
} // namespace mindspore
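
For reference, a minimal standalone sketch (assuming nlohmann::json, which ConfigManager already uses) of why a config file written before this change keeps working: j.value(key, fallback) returns the fallback when the key is absent, so numConnections and prefetchSize default to kDftNumConnections and kDftPrefetchSize.
#include <cstdint>
#include <iostream>
#include <nlohmann/json.hpp>
int main() {
  // A config that predates this change: no numConnections/prefetchSize keys.
  nlohmann::json j = {{"cacheHost", "127.0.0.1"}, {"cachePort", 50052}};
  // j.value(key, fallback) returns the fallback when the key is missing.
  int32_t num_connections = j.value("numConnections", 12);  // kDftNumConnections
  int32_t prefetch_size = j.value("prefetchSize", 20);      // kDftPrefetchSize
  std::cout << num_connections << " " << prefetch_size << std::endl;
  return 0;
}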

@ -97,6 +97,14 @@ class ConfigManager {
// @return The port of cache server
int32_t cache_port() const { return cache_port_; }
/// getter function
/// \return Number of tcp/ip connections
int32_t num_connections() const { return num_connections_; }
/// getter function
/// \return Prefetch size
int32_t prefetch_size() const { return prefetch_size_; }
// setter function
// @param rows_per_buffer - The setting to apply to the config
void set_rows_per_buffer(int32_t rows_per_buffer);
@ -121,6 +129,14 @@ class ConfigManager {
// @param cache_port - The port of cache server
void set_cache_port(int32_t cache_port);
/// setter function
/// \param num_connections - Number of tcp/ip connections
void set_num_connections(int32_t num_connections);
/// setter function
/// \param prefetch_size - Prefetch size
void set_prefetch_size(int32_t prefetch_size);
uint32_t seed() const;
// setter function
@ -153,6 +169,8 @@ class ConfigManager {
uint32_t callback_timout_;
std::string cache_host_;
int32_t cache_port_;
int32_t num_connections_;
int32_t prefetch_size_;
// Private helper function that takes a nlohmann json format and populates the settings
// @param j - The json nlohmann json info

@ -71,6 +71,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds
constexpr int32_t kCfgDefaultCachePort = 50052;
constexpr char kCfgDefaultCacheHost[] = "127.0.0.1";
constexpr int32_t kDftPrefetchSize = 20;
constexpr int32_t kDftNumConnections = 12;
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255;

@ -79,6 +79,14 @@ class TensorRow {
const vector_type &getRow() const { return row_; }
int64_t SizeInBytes() const {
size_t sz = 0;
for (auto &it : row_) {
sz += it->SizeInBytes();
}
return sz;
}
// Wrapper functions to support vector operations
void emplace_back(value_type t) { row_.emplace_back(t); }

@ -12,7 +12,9 @@ add_library(engine-cache-client OBJECT
if (ENABLE_CACHE)
ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS}
cache_grpc_client.cc
cache_ipc.cc)
add_library(engine-cache-server OBJECT
${CACHE_GRPC_SRCS}

@ -37,12 +37,17 @@ int main(int argc, char **argv) {
warningMsg += "WARNING:\n";
warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research";
warningMsg += " purposes at this time.\n";
warningMsg += "This command is currently disabled. Quitting.\n";
auto env_enable_cache = std::getenv("MS_ENABLE_CACHE");
if (env_enable_cache == nullptr || strcmp(env_enable_cache, "TRUE") != 0) {
// temporarily disable the cache feature in the current release
warningMsg += "This command is currently disabled. Quitting.\n";
std::cerr << warningMsg << std::endl;
return 0;
}
warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n";
// A warning message until the code is mature enough.
std::cerr << warningMsg << std::endl;
// temporary disable cache feature in the current release
return 0;
if (argc == 1) {
args.Help();

File diff suppressed because it is too large

@ -32,6 +32,7 @@ class CacheAdminArgHandler {
static constexpr int32_t kDefaultNumWorkers = 32;
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
static constexpr int32_t kDefaultLogLevel = 1;
static constexpr float kMemoryCapRatio = 0.8;
static const char kServerBinary[];
static const char kDefaultSpillDir[];
@ -42,12 +43,13 @@ class CacheAdminArgHandler {
kCmdStop = 2,
kCmdGenerateSession = 3,
kCmdDestroySession = 4,
kCmdListSessions = 5,
kCmdUnknown = 32767
};
CacheAdminArgHandler();
~CacheAdminArgHandler() = default;
virtual ~CacheAdminArgHandler();
Status ParseArgStream(std::stringstream *arg_stream);
@ -70,12 +72,12 @@ class CacheAdminArgHandler {
kArgNumWorkers = 9,
kArgSharedMemorySize = 10,
kArgLogLevel = 11,
kArgNumArgs = 12 // Must be the last position to provide a count
kArgMemoryCapRatio = 12,
kArgListSessions = 13,
kArgNumArgs = 14 // Must be the last position to provide a count
};
Status StartServer();
Status StopServer();
Status StartStopServer(CommandId);
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
@ -83,6 +85,9 @@ class CacheAdminArgHandler {
Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
Status AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
Status Validate();
CommandId command_id_;
@ -90,6 +95,7 @@ class CacheAdminArgHandler {
int32_t num_workers_;
int32_t shm_mem_sz_;
int32_t log_level_;
float memory_cap_ratio_;
session_id_type session_id_;
std::string hostname_;
std::string spill_dir_;

@ -17,27 +17,19 @@
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
: ptr_(nullptr), val_in_GB_(val_in_GB), port_(port), shmid_(-1) {}
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) {
// We create the shared memory and we will destroy it. All other clients only detach.
shm_.RemoveResourcesOnExit();
}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {
#if CACHE_LOCAL_CLIENT
if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
shmdt(this->ptr_);
}
this->ptr_ = nullptr;
if (shmid_ != -1) {
shmctl(shmid_, IPC_RMID, nullptr);
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
}
#endif
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
}
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
#if CACHE_LOCAL_CLIENT
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
if (ba == nullptr) {
return Status(StatusCode::kOutOfMemory);
@ -46,26 +38,13 @@ Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryAr
// the destructor of *out to deal.
(*out).reset(ba);
// Generate the ftok using a combination of port.
int err;
auto shm_key = PortToFtok(port, &err);
if (shm_key == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
SharedMemory::shm_key_t shm_key;
RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
ba->shm_.SetPublicKey(shm_key);
// Value is in GB. Convert into bytes.
int64_t sz = val_in_GB * 1073741824L;
ba->shmid_ = shmget(shm_key, sz, IPC_CREAT | IPC_EXCL | access_mode);
if (ba->shmid_) {
ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
ba->impl_ = std::make_unique<ArenaImpl>(ba->ptr_, sz);
} else {
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
}
#endif
RETURN_IF_NOT_OK(ba->shm_.Create(sz));
ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz);
return Status::OK();
}
} // namespace dataset

@ -21,6 +21,7 @@
#include <string>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
@ -73,10 +74,9 @@ class CachedSharedMemoryArena : public MemoryPool {
private:
mutable std::mutex mux_;
void *ptr_;
int32_t val_in_GB_;
int32_t port_;
int shmid_;
SharedMemory shm_;
std::unique_ptr<ArenaImpl> impl_;
/// Private constructor. Not to be called directly.
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);

@ -24,26 +24,26 @@
namespace mindspore {
namespace dataset {
CacheClient::Builder::Builder()
: session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) {
: session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_connections_(0), prefetch_size_(0) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
hostname_ = cfg->cache_host();
port_ = cfg->cache_port();
num_workers_ = cfg->num_parallel_workers();
prefetch_size_ = 20; // rows_per_buf is too small (1 by default).
num_connections_ = cfg->num_connections(); // number of async tcp/ip connections
prefetch_size_ = cfg->prefetch_size(); // prefetch size
}
Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) {
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_IF_NOT_OK(SanityCheck());
*out =
std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_);
*out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_connections_,
prefetch_size_);
return Status::OK();
}
Status CacheClient::Builder::SanityCheck() {
CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited)");
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive");
@ -55,26 +55,32 @@ Status CacheClient::Builder::SanityCheck() {
// Constructor
CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
int32_t port, int32_t num_workers, int32_t prefetch_size)
int32_t port, int32_t num_connections, int32_t prefetch_size)
: server_connection_id_(0),
cache_mem_sz_(cache_mem_sz),
spill_(spill),
local_bypass_(false),
hostname_(std::move(hostname)),
port_(port),
num_workers_(num_workers),
prefetch_size_(prefetch_size) {
num_connections_(num_connections),
prefetch_size_(prefetch_size),
fetch_all_keys_(true) {
cinfo_.set_session_id(session_id);
comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_);
comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_connections_);
}
CacheClient::~CacheClient() {
cache_miss_keys_wp_.Set();
(void)comm_->ServiceStop();
}
// Print method for displaying cache details
void CacheClient::Print(std::ostream &out) const {
out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc()
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << getCacheMemSz()
<< "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << getHostname()
<< "\n Port: " << getPort() << "\n Number of rpc workers: " << getNumWorkers()
<< "\n Prefetch size: " << getPrefetchSize() << "\n Local client support: " << std::boolalpha
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << GetCacheMemSz()
<< "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << GetHostname()
<< "\n Port: " << GetPort() << "\n Number of rpc workers: " << GetNumConnections()
<< "\n Prefetch size: " << GetPrefetchSize() << "\n Local client support: " << std::boolalpha
<< SupportLocalClient();
}
@ -199,14 +205,6 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
return Status::OK();
}
Status CacheClient::PurgeCache() {
UniqueLock lck(&mux_);
auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}
Status CacheClient::DestroyCache() {
UniqueLock lck(&mux_);
auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
@ -253,5 +251,71 @@ Status CacheClient::BuildPhaseDone() const {
}
Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
void CacheClient::ServerRunningOutOfResources() {
bool expected = true;
if (fetch_all_keys_.compare_exchange_strong(expected, false)) {
Status rc;
// Server runs out of memory or disk space to cache any more rows.
// First of all, we will turn off the locking.
auto toggle_write_mode_rq = std::make_shared<ToggleWriteModeRequest>(server_connection_id_, false);
rc = PushRequest(toggle_write_mode_rq);
if (rc.IsError()) {
return;
}
// Wait until we can toggle the state of the server to non-locking
rc = toggle_write_mode_rq->Wait();
if (rc.IsError()) {
return;
}
// Now we get a list of all the keys not cached at the server so
// we can filter out at the prefetch level.
auto cache_miss_rq = std::make_shared<GetCacheMissKeysRequest>(server_connection_id_);
rc = PushRequest(cache_miss_rq);
if (rc.IsError()) {
return;
}
rc = cache_miss_rq->Wait();
if (rc.IsError()) {
return;
}
// We will get back a vector of row id between [min,max] that are absent in the server.
auto &row_id_buf = cache_miss_rq->reply_.result();
auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
std::vector<row_id_type> row_ids;
auto sz = p->row_id()->size();
row_ids.reserve(sz);
for (auto i = 0; i < sz; ++i) {
row_ids.push_back(p->row_id()->Get(i));
}
cache_miss_keys_ = std::make_unique<CacheMissKeys>(row_ids);
// We are all set.
cache_miss_keys_wp_.Set();
}
}
CacheClient::CacheMissKeys::CacheMissKeys(const std::vector<row_id_type> &v) {
auto it = v.begin();
min_ = *it;
++it;
max_ = *it;
++it;
while (it != v.end()) {
gap_.insert(*it);
++it;
}
MS_LOG(WARNING) << "# of cache miss keys between min(" << min_ << ") and max(" << max_ << ") is " << gap_.size();
}
bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
if (key > max_ || key < min_) {
return true;
} else if (key == min_ || key == max_) {
return false;
} else {
auto it = gap_.find(key);
return it != gap_.end();
}
}
} // namespace dataset
} // namespace mindspore
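
For reference, a standalone sketch (plain C++, no project headers; MissKeys/IsMiss are illustrative names) of the interval-plus-gaps test CacheMissKeys implements: the reply is interpreted as {min, max, missing ids...}, ids outside [min, max] are always misses, the two endpoints are hits, and anything in between is a miss only if it sits in the gap set.
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>
using row_id_type = int64_t;
struct MissKeys {
  row_id_type min_, max_;
  std::set<row_id_type> gap_;
  // v is laid out as {min, max, gap...} and must hold at least two entries.
  explicit MissKeys(const std::vector<row_id_type> &v) : min_(v[0]), max_(v[1]), gap_(v.begin() + 2, v.end()) {}
  bool IsMiss(row_id_type key) const {
    if (key < min_ || key > max_) return true;
    if (key == min_ || key == max_) return false;
    return gap_.count(key) > 0;
  }
};
int main() {
  MissKeys mk({10, 100, 42, 77});  // cached range [10,100] with 42 and 77 missing
  std::cout << mk.IsMiss(5) << " " << mk.IsMiss(10) << " " << mk.IsMiss(42) << " " << mk.IsMiss(50) << "\n";  // 1 0 1 0
  return 0;
}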

@ -16,8 +16,13 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
#include <atomic>
#include <iostream>
#include <limits>
#include <memory>
#include <map>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
@ -31,6 +36,8 @@
#endif
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/cond_var.h"
#include "minddata/dataset/util/queue_map.h"
namespace mindspore {
namespace dataset {
@ -89,10 +96,10 @@ class CacheClient {
}
/// Setter function to set number of async rpc workers
/// \param num_workers
/// \param num_connections
/// \return Builder object itself
Builder &SetNumWorkers(int32_t num_workers) {
num_workers_ = num_workers;
Builder &SetNumConnections(int32_t num_connections) {
num_connections_ = num_connections;
return *this;
}
@ -105,13 +112,13 @@ class CacheClient {
}
/// Getter functions
session_id_type getSessionId() const { return session_id_; }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
session_id_type GetSessionId() const { return session_id_; }
uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }
int32_t GetPort() const { return port_; }
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
Status SanityCheck();
@ -123,7 +130,7 @@ class CacheClient {
bool spill_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t num_connections_;
int32_t prefetch_size_;
};
@ -132,10 +139,10 @@ class CacheClient {
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
int32_t num_workers, int32_t prefetch_size);
int32_t num_connections, int32_t prefetch_size);
/// \brief Destructor
~CacheClient() { (void)comm_->ServiceStop(); }
~CacheClient();
/// \brief Send a TensorRow to the cache server
/// \param[in] row
@ -161,10 +168,6 @@ class CacheClient {
/// \return Status object
Status CreateCache(uint32_t tree_crc, bool generate_id);
/// \brief Purge a cache. Cache can be reused after reset.
/// \return Status object
Status PurgeCache();
/// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
/// \return Status object
Status DestroyCache();
@ -218,12 +221,31 @@ class CacheClient {
/// Getter functions
session_id_type session_id() const { return cinfo_.session_id(); }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }
const std::string &GetHostname() const { return hostname_; }
int32_t GetPort() const { return port_; }
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
/// MergeOp will notify us when the server can't cache any more rows.
/// We will stop any attempt to fetch any rows that are most likely
/// not present at the server.
void ServerRunningOutOfResources();
/// \brief Check if a row is 100% cache miss at the server by checking the local information
/// \param key row id to be tested
/// \return true if not at the server
bool KeyIsCacheMiss(row_id_type key) {
if (cache_miss_keys_) {
// Make sure it is fully built even though the pointer is not null
Status rc = cache_miss_keys_wp_.Wait();
if (rc.IsOk()) {
return cache_miss_keys_->KeyIsCacheMiss(key);
}
}
return false;
}
private:
mutable RWLock mux_;
@ -240,9 +262,27 @@ class CacheClient {
bool local_bypass_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t num_connections_;
int32_t prefetch_size_;
mutable std::shared_ptr<CacheClientGreeter> comm_;
std::atomic<bool> fetch_all_keys_;
WaitPost cache_miss_keys_wp_;
/// A structure shared by all the prefetchers to know what keys are missing at the server.
class CacheMissKeys {
public:
explicit CacheMissKeys(const std::vector<row_id_type> &v);
~CacheMissKeys() = default;
/// This checks if a key is missing.
/// \param key
/// \return true if definitely a key miss
bool KeyIsCacheMiss(row_id_type key);
private:
row_id_type min_;
row_id_type max_;
std::set<row_id_type> gap_;
};
std::unique_ptr<CacheMissKeys> cache_miss_keys_;
};
} // namespace dataset
} // namespace mindspore

@ -25,13 +25,6 @@
#define CACHE_LOCAL_CLIENT 1
#endif
#ifdef CACHE_LOCAL_CLIENT
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#else
typedef int key_t;
#endif
#ifdef ENABLE_CACHE
#include <grpcpp/grpcpp.h>
#endif
@ -54,6 +47,8 @@ constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
/// inline in the protobuf. This also implies kLocalClientSupport is also true.
constexpr static uint32_t kDataIsInSharedMemory = 2;
/// \brief Size of each message used in message queue.
constexpr static int32_t kSharedMessageSize = 2048;
/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
@ -62,29 +57,10 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
reply->set_rc(static_cast<int32_t>(rc.get_code()));
reply->set_msg(rc.ToString());
}
/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
/// \param port
/// \return unix socket url
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
/// \brief Generate a shared memory key using the tcp/ip port.
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
/// \note Caller must check the return value. -1 means ftok failed.
/// \param[in] port
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
/// \return key
inline key_t PortToFtok(int port, int *err) {
key_t shmkey = -1;
#ifdef CACHE_LOCAL_CLIENT
const std::string unix_path = PortToUnixSocketPath(port);
shmkey = ftok(unix_path.data(), 'a');
if (err != nullptr && shmkey == (key_t)-1) {
*err = errno;
}
#endif
return shmkey;
}
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_

@ -17,34 +17,10 @@
#include <chrono>
namespace mindspore {
namespace dataset {
Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag) {
// If there is anything extra we need to do before we send.
RETURN_IF_NOT_OK(tag->base_rq_->Prepare());
// One minute timeout
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
tag->ctx_.set_deadline(deadline);
tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq);
tag->rpc_->StartCall();
// Last step is we release the ownership and transfer it to the completion queue.
// The memory will be released by WorkerEntry or by the destructor when we drain the queue
auto ccReqTag = tag.release();
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_,
ccReqTag); // inject this object into the completion queue
return Status::OK();
}
CacheClientGreeter::~CacheClientGreeter() { (void)ServiceStop(); }
CacheClientGreeter::~CacheClientGreeter() {
(void)ServiceStop();
// Detach from shared memory if any
if (shmat_addr_ != nullptr) {
shmdt(shmat_addr_);
shmat_addr_ = nullptr;
}
}
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers)
: num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) {
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections)
: num_connections_(num_connections), request_cnt_(0) {
grpc::ChannelArguments args;
// We need to bump up the message size to unlimited. The default receiving
// message limit is 4MB which is not big enough.
@ -68,21 +44,11 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port
Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) {
*local_bypass = false;
#if CACHE_LOCAL_CLIENT
int err;
shm_key_ = PortToFtok(port, &err);
if (shm_key_ == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
SharedMemory::shm_key_t shm_key;
RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
// Attach to the shared memory
shm_id_ = shmget(shm_key_, 0, 0);
if (shm_id_ == -1) {
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
}
shmat_addr_ = shmat(shm_id_, nullptr, 0);
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
mem_.SetPublicKey(shm_key);
RETURN_IF_NOT_OK(mem_.Attach());
*local_bypass = true;
#endif
return Status::OK();
@ -90,7 +56,7 @@ Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass
Status CacheClientGreeter::DoServiceStart() {
RETURN_IF_NOT_OK(vg_.ServiceStart());
RETURN_IF_NOT_OK(DispatchWorkers(num_workers_));
RETURN_IF_NOT_OK(DispatchWorkers(num_connections_));
return Status::OK();
}
@ -100,19 +66,40 @@ Status CacheClientGreeter::DoServiceStop() {
// Shutdown the TaskGroup.
vg_.interrupt_all();
vg_.join_all(Task::WaitFlag::kNonBlocking);
// Drain the queue
bool success;
void *tag;
while (cq_.Next(&tag, &success)) {
auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
delete r;
// Drain the queue. We know how many requests we sent out
while (!req_.empty()) {
bool success;
void *tag;
while (cq_.Next(&tag, &success)) {
auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
req_.erase(r->seqNo_);
}
}
return Status::OK();
}
Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq));
return tag->MakeCall(stub_.get(), &cq_, std::move(tag));
// If there is anything extra we need to do before we send.
RETURN_IF_NOT_OK(rq->Prepare());
auto seqNo = request_cnt_.fetch_add(1);
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq), seqNo);
// One minute timeout
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
tag->ctx_.set_deadline(deadline);
tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_);
tag->rpc_->StartCall();
auto ccReqTag = tag.get();
// Insert it into the map.
{
std::unique_lock<std::mutex> lck(mux_);
auto r = req_.emplace(seqNo, std::move(tag));
if (!r.second) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__);
}
}
// Last step is to tag the request.
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, ccReqTag);
return Status::OK();
}
Status CacheClientGreeter::WorkerEntry() {
@ -129,15 +116,26 @@ Status CacheClientGreeter::WorkerEntry() {
auto &rc = rq->rc_;
if (!rc.ok()) {
auto error_code = rq->rc_.error_code();
std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
std::string err_msg;
if (error_code == grpc::StatusCode::UNAVAILABLE) {
err_msg =
"Cache server is unreachable. Make sure the server is running. GRPC Code" + std::to_string(error_code);
} else {
err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
}
Status remote_rc = Status(StatusCode::kNetWorkError, __LINE__, __FILE__, err_msg);
Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
}
// Notify the waiting thread.
rq->Notify();
}
// We can now free the memory
delete rq;
{
// We can now free the memory
std::unique_lock<std::mutex> lck(mux_);
auto seqNo = rq->seqNo_;
auto n = req_.erase(seqNo);
CHECK_FAIL_RETURN_UNEXPECTED(n == 1, "Sequence " + std::to_string(seqNo) + " not found");
}
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED();
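
For reference, a standalone sketch of the ownership pattern introduced here for in-flight requests: a sequence number from an atomic counter keys each request tag in a mutex-protected map, so the completion path erases exactly one entry and shutdown can drain until the map is empty. RequestTracker and Tag are illustrative names, not the project's.
#include <atomic>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
struct Tag { int64_t seq_no; };
class RequestTracker {
 public:
  // Register a new request and hand back a raw pointer to use as the async tag.
  Tag *Add() {
    int64_t seq = request_cnt_.fetch_add(1);
    auto tag = std::make_unique<Tag>(Tag{seq});
    Tag *raw = tag.get();
    std::unique_lock<std::mutex> lck(mux_);
    req_.emplace(seq, std::move(tag));
    return raw;
  }
  // Called when the completion queue hands the tag back; frees exactly one entry.
  void Complete(Tag *raw) {
    std::unique_lock<std::mutex> lck(mux_);
    req_.erase(raw->seq_no);
  }
  bool Empty() const {
    std::unique_lock<std::mutex> lck(mux_);
    return req_.empty();
  }
 private:
  std::atomic<int64_t> request_cnt_{0};
  mutable std::mutex mux_;
  std::map<int64_t, std::unique_ptr<Tag>> req_;
};
int main() {
  RequestTracker t;
  Tag *a = t.Add();
  t.Complete(a);
  return t.Empty() ? 0 : 1;
}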

@ -16,10 +16,14 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
@ -34,16 +38,10 @@ namespace dataset {
class CacheClientRequestTag {
public:
friend class CacheClientGreeter;
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {}
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq, int64_t seqNo)
: base_rq_(std::move(rq)), seqNo_(seqNo) {}
~CacheClientRequestTag() = default;
/// \brief Make a RPC call
/// \param stub from CacheClientGreeter
/// \param cq from CacheClientGreeter
/// \return Status object
static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag);
/// \brief Notify the client that a result has come back from the server
void Notify() { base_rq_->wp_.Set(); }
@ -52,6 +50,7 @@ class CacheClientRequestTag {
grpc::Status rc_;
grpc::ClientContext ctx_;
std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_;
int64_t seqNo_;
};
/// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC
@ -60,7 +59,7 @@ class CacheClientGreeter : public Service {
friend class CacheClient;
public:
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers);
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections);
~CacheClientGreeter();
/// Override base Service class
@ -85,17 +84,18 @@ class CacheClientGreeter : public Service {
/// \brief This returns where we attach to the shared memory.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
const void *SharedMemoryBaseAddr() const { return mem_.SharedMemoryBaseAddr(); }
private:
std::shared_ptr<grpc::Channel> channel_;
std::unique_ptr<CacheServerGreeter::Stub> stub_;
grpc::CompletionQueue cq_;
TaskGroup vg_;
int32_t num_workers_;
key_t shm_key_;
int32_t shm_id_;
void *shmat_addr_;
int32_t num_connections_;
std::atomic<int64_t> request_cnt_;
mutable std::mutex mux_;
std::map<int64_t, std::unique_ptr<CacheClientRequestTag>> req_;
SharedMemory mem_;
};
} // namespace dataset
} // namespace mindspore

@ -47,53 +47,10 @@ void CacheServerGreeterImpl::Shutdown() {
CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); }
Status CacheServerGreeterImpl::IpcResourceCleanup() {
#if CACHE_LOCAL_CLIENT
int err;
auto shm_key = PortToFtok(port_, &err);
// We are expecting the unix path doesn't exist.
if (shm_key == (key_t)-1) {
return Status::OK();
}
// Attach to the shared memory
auto shm_id = shmget(shm_key, 0, 0);
if (shm_id == -1) {
return Status::OK();
}
struct shmid_ds ds {};
auto inx = shmctl(shm_id, IPC_STAT, &ds);
if (inx == -1) {
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
if (ds.shm_nattch == 0) {
// Stale shared memory from last time.
// Remove both the memory and the socket path
inx = shmctl(shm_id, IPC_RMID, nullptr);
if (inx == -1) {
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id);
errMsg += ". Errno :" + std::to_string(errno);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
Path p(unix_socket_);
(void)p.Remove();
} else {
// Server is already up.
MS_LOG(ERROR) << "Cache server is already up and running";
// We return a duplicate error. The main() will intercept
// and output a proper message
return Status(StatusCode::kDuplicateKey);
}
#endif
return Status::OK();
}
Status CacheServerGreeterImpl::Run() {
// To listen on all interfaces, use 0.0.0.0
// Use 127.0.0.1 if just locally on the same machine.
std::string host("0.0.0.0"); // listen on all interfaces.
// In the future, allow the user to choose the listening interface. For now, default to localhost
std::string host("127.0.0.1");
std::string server_address = host + ":" + std::to_string(port_);
grpc::ServerBuilder builder;
// Default message size for gRPC is 4MB. Increase it to 2g-1
@ -101,9 +58,6 @@ Status CacheServerGreeterImpl::Run() {
int port_tcpip = 0;
#if CACHE_LOCAL_CLIENT
int port_local = 0;
// Check if we need to do clean up on the shared memory if the server
// came down unexpectedly like SEGV
RETURN_IF_NOT_OK(IpcResourceCleanup());
// We also optimize on local clients on the same machine using unix socket
builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local);
#endif

@ -41,7 +41,7 @@ class CacheServerRequest : public BaseRequest {
st_(STATE::CREATE),
responder_(&ctx_) {}
~CacheServerRequest() = default;
~CacheServerRequest() override = default;
/// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this
/// functor will translate each protobuf into some form understood by the CacheService class.
@ -87,8 +87,6 @@ class CacheServerGreeterImpl final {
void Shutdown();
Status IpcResourceCleanup();
private:
int32_t port_;
size_t shm_pool_sz_in_gb_;

@ -0,0 +1,163 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include <sys/stat.h>
namespace mindspore {
namespace dataset {
Status PortToFtok(int port, SharedMemory::shm_key_t *out) {
RETURN_UNEXPECTED_IF_NULL(out);
key_t shmkey = -1;
const std::string unix_path = PortToUnixSocketPath(port);
shmkey = ftok(unix_path.data(), 'a');
if (shmkey == (key_t)-1) {
std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno);
return Status(errno == ENOENT ? StatusCode::kFileNotExist : StatusCode::kUnexpectedError, errMsg);
}
*out = shmkey;
return Status::OK();
}
SharedMessage::~SharedMessage() {
// Only remove the queue if we are asked to.
if (remove_ipc_on_exit_ && msg_qid_ != -1) {
// Remove the message queue and never mind about the return code.
(void)msgctl(msg_qid_, IPC_RMID, nullptr);
msg_qid_ = -1;
}
}
Status SharedMessage::Create() {
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ == -1, "Message queue already created");
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
msg_qid_ = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode);
if (msg_qid_ == -1) {
std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status SharedMessage::SendStatus(const Status &rc) {
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
StatusMsgBuf msg{
1,
};
msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size());
CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err));
msg.body.status.err_msg[rc.ToString().size()] = '\0';
err = msgsnd(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), IPC_NOWAIT);
if (err == -1) {
std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status SharedMessage::ReceiveStatus(Status *rc) {
RETURN_UNEXPECTED_IF_NULL(rc);
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
struct StatusMsgBuf msg {};
auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR);
if (err == -1) {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
Status rc_recv(static_cast<StatusCode>(msg.body.status.err_code), msg.body.status.err_msg);
*rc = std::move(rc_recv);
return Status::OK();
}
SharedMemory::~SharedMemory() {
if (shmat_addr_) {
(void)Detach();
}
if (remove_ipc_on_exit_ && shm_id_ != -1) {
// Remove the shared memory and never mind about the return code.
Status rc = Destroy();
if (rc.IsError()) {
MS_LOG(ERROR) << rc.ToString();
}
}
shm_id_ = -1;
shmat_addr_ = nullptr;
}
Status SharedMemory::Create(int64_t sz) {
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
shm_id_ = shmget(shm_key_, sz, IPC_CREAT | IPC_EXCL | access_mode);
if (shm_id_ == -1) {
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
} else {
shmat_addr_ = shmat(shm_id_, nullptr, 0);
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
}
return Status::OK();
}
Status SharedMemory::Attach() {
shm_id_ = shmget(shm_key_, 0, 0);
if (shm_id_ == -1) {
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
}
shmat_addr_ = shmat(shm_id_, nullptr, 0);
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
return Status::OK();
}
Status SharedMemory::Detach() {
if (shmat_addr_) {
auto err = shmdt(shmat_addr_);
if (err == -1) {
RETURN_STATUS_UNEXPECTED("Shared memory detach failed. Errno " + std::to_string(errno));
}
}
shmat_addr_ = nullptr;
return Status::OK();
}
Status SharedMemory::Destroy() {
// Remove the shared memory and never mind about the return code.
auto err = shmctl(shm_id_, IPC_RMID, nullptr);
if (err == -1) {
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_);
errMsg += ". Errno :" + std::to_string(errno);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status SharedMemory::GetNumAttached(int32_t *num) {
RETURN_UNEXPECTED_IF_NULL(num);
struct shmid_ds ds {};
auto err = shmctl(shm_id_, IPC_STAT, &ds);
if (err == -1) {
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id_);
errMsg += "\nPlease remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
*num = ds.shm_nattch;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
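
For reference, a standalone sketch of the System V lifecycle the SharedMemory wrapper encapsulates: the creator shmget()s with IPC_CREAT|IPC_EXCL, both sides shmat(), and only the owner marks the segment for removal (the wrapper's RemoveResourcesOnExit()). The key source ("/tmp") and segment size are illustrative; the real code derives the key from the unix socket path via PortToFtok.
#include <cstdio>
#include <cstring>
#include <sys/ipc.h>
#include <sys/shm.h>
int main() {
  key_t key = ftok("/tmp", 'a');  // the real code uses the server's unix socket path
  if (key == (key_t)-1) { perror("ftok"); return 1; }
  int shm_id = shmget(key, 1 << 20, IPC_CREAT | IPC_EXCL | 0600);  // Create(): 1MB segment
  if (shm_id == -1) { perror("shmget"); return 1; }
  void *addr = shmat(shm_id, nullptr, 0);  // Attach(); clients first look it up with shmget(key, 0, 0)
  if (addr == reinterpret_cast<void *>(-1)) { perror("shmat"); return 1; }
  std::strcpy(static_cast<char *>(addr), "hello");
  shmdt(addr);                         // Detach()
  shmctl(shm_id, IPC_RMID, nullptr);   // Destroy(): only the owner removes the segment
  return 0;
}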

@ -0,0 +1,207 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sys/msg.h>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// A message queue structure between the parent and the child process
struct StatusMsgBuf {
int64_t mtype;
union {
char mtext[1];
struct {
int32_t err_code;
char err_msg[kSharedMessageSize];
} status;
} body;
};
class BaseIPC {
public:
BaseIPC() : remove_ipc_on_exit_(false) {}
virtual ~BaseIPC() {}
/// Indicate if we should remove the ipc resource on exit. Usually this is done by parent process.
void RemoveResourcesOnExit() { remove_ipc_on_exit_ = true; }
/// Copy constructors
BaseIPC(const BaseIPC &rhs) : remove_ipc_on_exit_(false) {}
BaseIPC &operator=(const BaseIPC &rhs) {
if (&rhs != this) {
remove_ipc_on_exit_ = false;
}
return *this;
}
/// Move constructors
BaseIPC(BaseIPC &&rhs) noexcept : remove_ipc_on_exit_(rhs.remove_ipc_on_exit_) { rhs.remove_ipc_on_exit_ = false; }
BaseIPC &operator=(BaseIPC &&rhs) noexcept {
if (&rhs != this) {
remove_ipc_on_exit_ = rhs.remove_ipc_on_exit_;
rhs.remove_ipc_on_exit_ = false;
}
return *this;
}
protected:
bool remove_ipc_on_exit_;
};
/// \brief This wraps a shared message for the communication between processes. It is used primarily
/// for starting and stopping a server.
class SharedMessage : public BaseIPC {
public:
using queue_id_t = int;
SharedMessage() : msg_qid_(-1) {}
explicit SharedMessage(queue_id_t qid) : msg_qid_(qid) {}
~SharedMessage() override;
/// Copy constructors
SharedMessage(const SharedMessage &rhs) : BaseIPC(rhs), msg_qid_(rhs.msg_qid_) {}
SharedMessage &operator=(const SharedMessage &rhs) {
if (&rhs != this) {
msg_qid_ = rhs.msg_qid_;
BaseIPC::operator=(rhs);
}
return *this;
}
/// Move constructors
SharedMessage(SharedMessage &&rhs) noexcept : BaseIPC(std::move(rhs)) {
msg_qid_ = rhs.msg_qid_;
rhs.msg_qid_ = -1;
}
SharedMessage &operator=(SharedMessage &&rhs) noexcept {
if (&rhs != this) {
msg_qid_ = rhs.msg_qid_;
rhs.msg_qid_ = -1;
BaseIPC::operator=(std::move(rhs));
}
return *this;
}
/// Return the private id
queue_id_t GetMsgQueueId() const { return msg_qid_; }
/// \brief Create a private message queue
Status Create();
/// Send a Status object
Status SendStatus(const Status &rc);
/// Retrieve a Status object
Status ReceiveStatus(Status *rc);
private:
queue_id_t msg_qid_;
};
/// \brief This wraps a shared memory for the communication between processes. It is used primarily
/// for transporting large tensor rows.
class SharedMemory : public BaseIPC {
public:
using shm_key_t = int;
using shm_id_t = int;
SharedMemory() : shm_id_(-1), shm_key_(-1), shmat_addr_(nullptr) {}
explicit SharedMemory(shm_key_t public_key) : shm_id_(-1), shm_key_(public_key), shmat_addr_(nullptr) {}
~SharedMemory() override;
/// Copy constructors
SharedMemory(const SharedMemory &rhs)
: BaseIPC(rhs), shm_id_(rhs.shm_id_), shm_key_(rhs.shm_key_), shmat_addr_(rhs.shmat_addr_) {}
SharedMemory &operator=(const SharedMemory &rhs) {
if (&rhs != this) {
shm_id_ = rhs.shm_id_;
shm_key_ = rhs.shm_key_;
shmat_addr_ = rhs.shmat_addr_;
BaseIPC::operator=(rhs);
}
return *this;
}
/// Move constructors
SharedMemory(SharedMemory &&rhs) noexcept : BaseIPC(std::move(rhs)) {
shm_id_ = rhs.shm_id_;
shm_key_ = rhs.shm_key_;
shmat_addr_ = rhs.shmat_addr_;
rhs.shm_id_ = -1;
rhs.shm_key_ = -1;
rhs.shmat_addr_ = nullptr;
}
SharedMemory &operator=(SharedMemory &&rhs) noexcept {
if (&rhs != this) {
shm_id_ = rhs.shm_id_;
shm_key_ = rhs.shm_key_;
shmat_addr_ = rhs.shmat_addr_;
rhs.shm_id_ = -1;
rhs.shm_key_ = -1;
rhs.shmat_addr_ = nullptr;
BaseIPC::operator=(std::move(rhs));
}
return *this;
}
/// \brief Set the public key
void SetPublicKey(key_t public_key) { shm_key_ = public_key; }
/// \brief This returns where we attach to the shared memory.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
void *SharedMemoryBaseAddr() { return shmat_addr_; }
/// \brief Attach to shared memory
/// \return Status object
Status Attach();
/// Detach from shared memory
/// \return Status object
Status Detach();
/// Create shared memory
/// \return Status object
Status Create(int64_t sz);
/// Destroy shared memory
/// \return Status object
Status Destroy();
/// \brief Return the shared memory id
shm_id_t GetSharedMemoryId() const { return shm_id_; }
/// \brief Get number of processes attached to the shared memory
/// \return Status object
Status GetNumAttached(int32_t *num);
private:
shm_id_t shm_id_;
shm_key_t shm_key_;
void *shmat_addr_;
};
/// \brief Generate a shared memory key using the tcp/ip port.
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
/// \param[in] port
/// \param[out] out The generated shared memory key
/// \return Status object
Status PortToFtok(int port, SharedMemory::shm_key_t *out);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_
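
For reference, a standalone sketch of the System V message-queue handshake SharedMessage wraps: a private queue is created, one side msgsnd()s a small status record, the other msgrcv()s it, and the owner removes the queue on exit. The Msg struct mirrors StatusMsgBuf in spirit; the field sizes are illustrative.
#include <cstdio>
#include <sys/ipc.h>
#include <sys/msg.h>
struct Msg {
  long mtype;
  struct { int err_code; char err_msg[64]; } status;
};
int main() {
  int qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | 0600);  // Create()
  if (qid == -1) { perror("msgget"); return 1; }
  Msg out{1, {0, "OK"}};
  if (msgsnd(qid, &out, sizeof(out.status), IPC_NOWAIT) == -1) { perror("msgsnd"); return 1; }  // SendStatus()
  Msg in{};
  if (msgrcv(qid, &in, sizeof(in.status), 0, MSG_NOERROR) == -1) { perror("msgrcv"); return 1; }  // ReceiveStatus()
  std::printf("code=%d msg=%s\n", in.status.err_code, in.status.err_msg);
  msgctl(qid, IPC_RMID, nullptr);  // remove the queue on exit, as the wrapper's destructor does
  return 0;
}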

File diff suppressed because it is too large

@ -250,5 +250,27 @@ Status GetStatRequest::PostReply() {
stat_.cache_service_state = msg->state();
return Status::OK();
}
Status ListSessionsRequest::PostReply() {
auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
auto session_vector = msg->sessions();
for (auto i = 0; i < session_vector->size(); ++i) {
SessionCacheInfo current_info;
CacheServiceStat stats;
auto current_session_info = session_vector->Get(i);
current_info.session_id = current_session_info->session_id();
current_info.connection_id = current_session_info->connection_id();
stats.num_mem_cached = current_session_info->stats()->num_mem_cached();
stats.num_disk_cached = current_session_info->stats()->num_disk_cached();
stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz();
stats.min_row_id = current_session_info->stats()->min_row_id();
stats.max_row_id = current_session_info->stats()->max_row_id();
stats.cache_service_state = current_session_info->stats()->state();
current_info.stats = stats; // fixed length struct. = operator is safe
session_info_list_.push_back(current_info);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -46,6 +46,13 @@ struct CacheServiceStat {
int8_t cache_service_state;
};
/// \brief Info structure returned by ListSessionsRequest
struct SessionCacheInfo {
session_id_type session_id;
connection_id_type connection_id;
CacheServiceStat stats;
};
/// \brief CacheClient communicates with CacheServer using Requests.
class BaseRequest {
public:
@ -54,7 +61,7 @@ class BaseRequest {
kCacheRow = 0,
kBatchFetchRows = 1,
kCreateCache = 2,
kPurgeCache = 3,
kGetCacheMissKeys = 3,
kDestroyCache = 4,
kGetStat = 5,
kCacheSchema = 6,
@ -65,6 +72,9 @@ class BaseRequest {
kAllocateSharedBlock = 11,
kFreeSharedBlock = 12,
kStopService = 13,
kHeartBeat = 14,
kToggleWriteMode = 15,
kListSessions = 16,
// Add new request before it.
kRequestUnknown = 32767
};
@ -73,6 +83,7 @@ class BaseRequest {
friend class CacheServerRequest;
friend class CacheClientGreeter;
friend class CacheClientRequestTag;
friend class CacheClient;
/// \brief Base class of a cache server request
/// \param type Type of the request
@ -119,7 +130,7 @@ class FreeSharedBlockRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(addr));
}
~FreeSharedBlockRequest() = default;
~FreeSharedBlockRequest() override = default;
};
/// \brief Request to cache a single TensorRow
@ -136,7 +147,7 @@ class CacheRowRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie);
}
~CacheRowRequest() = default;
~CacheRowRequest() override = default;
/// \brief Serialize a TensorRow for streaming to the cache server
/// \param row TensorRow
@ -183,7 +194,7 @@ class BatchFetchRequest : public BaseRequest {
friend class CacheServer;
friend class CacheService;
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass);
~BatchFetchRequest() = default;
~BatchFetchRequest() override = default;
Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr);
private:
@ -203,7 +214,7 @@ class CreateCacheRequest : public BaseRequest {
/// \param flag Attributes of the cache.
explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone);
~CreateCacheRequest() = default;
~CreateCacheRequest() override = default;
void ParseResult(connection_id_type *id, std::string *out) {
auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
*id = p->connection_id();
@ -218,14 +229,15 @@ class CreateCacheRequest : public BaseRequest {
CreateCacheFlag flag_;
};
/// \brief Request to purge a cache.
class PurgeCacheRequest : public BaseRequest {
/// \brief Request to get all the keys not present at the server.
/// \note Only applicable to the mappable case
class GetCacheMissKeysRequest : public BaseRequest {
public:
friend class CacheServer;
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) {
explicit GetCacheMissKeysRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetCacheMissKeys) {
rq_.set_connection_id(connection_id);
}
~PurgeCacheRequest() = default;
~GetCacheMissKeysRequest() override = default;
};
/// \brief Request to destroy a cache
@ -235,7 +247,7 @@ class DestroyCacheRequest : public BaseRequest {
explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) {
rq_.set_connection_id(connection_id);
}
~DestroyCacheRequest() = default;
~DestroyCacheRequest() override = default;
};
/// \brief Obtain the statistics of the current connection
@ -247,7 +259,7 @@ class GetStatRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
}
~GetStatRequest() = default;
~GetStatRequest() override = default;
/// \brief Override base function to process the result.
Status PostReply() override;
@ -269,7 +281,7 @@ class CacheSchemaRequest : public BaseRequest {
explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) {
rq_.set_connection_id(connection_id);
}
~CacheSchemaRequest() = default;
~CacheSchemaRequest() override = default;
Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
};
@ -281,7 +293,7 @@ class FetchSchemaRequest : public BaseRequest {
explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) {
rq_.set_connection_id(connection_id);
}
~FetchSchemaRequest() = default;
~FetchSchemaRequest() override = default;
Status PostReply() override;
@ -300,7 +312,7 @@ class BuildPhaseDoneRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie_);
}
~BuildPhaseDoneRequest() = default;
~BuildPhaseDoneRequest() override = default;
private:
std::string cookie_;
@ -313,7 +325,7 @@ class DropSessionRequest : public BaseRequest {
explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) {
rq_.mutable_connection_info()->operator=(cinfo);
}
~DropSessionRequest() = default;
~DropSessionRequest() override = default;
};
class GenerateSessionIdRequest : public BaseRequest {
@ -325,11 +337,36 @@ class GenerateSessionIdRequest : public BaseRequest {
rq_.set_connection_id(0);
}
~GenerateSessionIdRequest() = default;
~GenerateSessionIdRequest() override = default;
session_id_type GetSessionId() { return atoi(reply_.result().data()); }
};
class ListSessionsRequest : public BaseRequest {
public:
friend class CacheServer;
ListSessionsRequest() : BaseRequest(RequestType::kListSessions) {
// This request is not specific to any cache or session
rq_.set_connection_id(0);
}
~ListSessionsRequest() override = default;
/// \brief Override base function to process the result.
Status PostReply() override;
void GetSessionCacheInfo(std::vector<SessionCacheInfo> *info) {
if (info != nullptr) {
(*info) = session_info_list_;
}
}
std::vector<SessionCacheInfo> GetSessionCacheInfo() { return session_info_list_; }
private:
std::vector<SessionCacheInfo> session_info_list_;
};
class AllocateSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
@ -338,7 +375,7 @@ class AllocateSharedBlockRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(requestedSz));
}
~AllocateSharedBlockRequest() = default;
~AllocateSharedBlockRequest() override = default;
/// \brief On return from the server, we get the (relative) address where
/// the free block is located.
@ -349,11 +386,15 @@ class AllocateSharedBlockRequest : public BaseRequest {
}
};
class ShutdownRequest : public BaseRequest {
class ToggleWriteModeRequest : public BaseRequest {
public:
friend class CacheServer;
ShutdownRequest() : BaseRequest(RequestType::kStopService) {}
~ShutdownRequest() = default;
explicit ToggleWriteModeRequest(connection_id_type connection_id, bool on_off)
: BaseRequest(RequestType::kToggleWriteMode) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(on_off ? "on" : "off");
}
~ToggleWriteModeRequest() override = default;
};
} // namespace dataset
} // namespace mindspore

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff
