Phase 2 of CacheOp

pull/4476/head
Jesse Lee 5 years ago
parent b11ef57b65
commit 8a08d0c37b

@ -24,6 +24,11 @@ if (ENABLE_TDTQUE)
add_definitions(-D ENABLE_TDTQUE) add_definitions(-D ENABLE_TDTQUE)
message(STATUS "TDT queue is enabled") message(STATUS "TDT queue is enabled")
endif () endif ()
if (MS_BUILD_GRPC)
set (ENABLE_CACHE true)
add_definitions(-D ENABLE_CACHE)
message(STATUS "Cache is enabled")
endif()
# conde coverage # conde coverage
# option(ENABLE_COVERAGE "Enable code coverage report" OFF) # option(ENABLE_COVERAGE "Enable code coverage report" OFF)
@ -47,10 +52,6 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
################## Include sub-modules ############################### ################## Include sub-modules ###############################
add_subdirectory(util) add_subdirectory(util)
add_subdirectory(core) add_subdirectory(core)
@ -70,8 +71,6 @@ add_dependencies(engine-datasetops-source-sampler core)
add_dependencies(engine-datasetops core) add_dependencies(engine-datasetops core)
add_dependencies(engine-datasetops-mapop core) add_dependencies(engine-datasetops-mapop core)
add_dependencies(engine-opt core) add_dependencies(engine-opt core)
add_dependencies(engine-cache-client core)
add_dependencies(engine-cache-server core)
add_dependencies(engine-perf core) add_dependencies(engine-perf core)
add_dependencies(engine-gnn core) add_dependencies(engine-gnn core)
add_dependencies(engine core) add_dependencies(engine core)
@ -85,7 +84,11 @@ endif()
if (ENABLE_TDTQUE) if (ENABLE_TDTQUE)
add_dependencies(engine-tdt core) add_dependencies(engine-tdt core)
endif () endif ()
if (ENABLE_CACHE)
add_dependencies(engine-datasetops engine-cache-client)
add_dependencies(engine-cache-client core)
add_dependencies(engine-cache-server core)
endif ()
################### Create _c_dataengine Library ###################### ################### Create _c_dataengine Library ######################
set(submodules set(submodules
$<TARGET_OBJECTS:core> $<TARGET_OBJECTS:core>
@ -105,7 +108,6 @@ set(submodules
$<TARGET_OBJECTS:engine-datasetops> $<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt> $<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine-cache-client> $<TARGET_OBJECTS:engine-cache-client>
$<TARGET_OBJECTS:engine-cache-server>
$<TARGET_OBJECTS:engine> $<TARGET_OBJECTS:engine>
$<TARGET_OBJECTS:text> $<TARGET_OBJECTS:text>
$<TARGET_OBJECTS:text-kernels> $<TARGET_OBJECTS:text-kernels>
@ -123,8 +125,6 @@ else ()
add_library(_c_dataengine SHARED ${submodules}) add_library(_c_dataengine SHARED ${submodules})
endif () endif ()
add_dependencies(_c_dataengine generated_engine_files)
if (ENABLE_PYTHON) if (ENABLE_PYTHON)
set_target_properties(_c_dataengine PROPERTIES set_target_properties(_c_dataengine PROPERTIES
PREFIX "${PYTHON_MODULE_PREFIX}" PREFIX "${PYTHON_MODULE_PREFIX}"
@ -187,6 +187,6 @@ else()
endif () endif ()
endif() endif()
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows") if (MS_BUILD_GRPC)
target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++) target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
endif() endif()

@ -22,7 +22,25 @@ namespace dataset {
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient") (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def(py::init<uint32_t, uint64_t, bool>()); .def(
py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) {
std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder;
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize(
prefetch_sz);
THROW_IF_ERROR(builder.Build(&cc));
return cc;
}))
.def("GetStat", [](CacheClient &cc) {
CacheServiceStat stat{};
THROW_IF_ERROR(cc.GetStat(&stat));
return stat;
});
(void)py::class_<CacheServiceStat>(*m, "CacheServiceStat")
.def(py::init<>())
.def_readwrite("avg_cache_sz", &CacheServiceStat::avg_cache_sz)
.def_readwrite("num_mem_cached", &CacheServiceStat::num_mem_cached)
.def_readwrite("num_disk_cached", &CacheServiceStat::num_disk_cached);
})); }));
} // namespace dataset } // namespace dataset

@ -72,7 +72,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255; constexpr uint8_t kCVInvalidType = 255;
using connection_id_type = int64_t; using connection_id_type = uint64_t;
using session_id_type = uint32_t;
using row_id_type = int64_t; using row_id_type = int64_t;
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -20,10 +20,8 @@ if (ENABLE_PYTHON)
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
endif() endif()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop)
if (ENABLE_TDTQUE) if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf add_dependencies(engine engine-tdt)
engine-cache-client engine-cache-server engine-datasetops-mapop)
else ()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server engine-datasetops-mapop)
endif () endif ()

@ -1,8 +1,47 @@
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-cache-client OBJECT add_library(engine-cache-client OBJECT
cache_client.cc cache_client.cc
cache_fbb.cc
cache_request.cc) cache_request.cc)
add_library(engine-cache-server OBJECT
if (ENABLE_CACHE)
ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)
add_library(engine-cache-server OBJECT
${CACHE_GRPC_SRCS}
cache_grpc_server.cc
cache_arena.cc
cache_service.cc cache_service.cc
cache_server.cc) cache_server.cc)
add_executable(cache_server cache_main.cc)
target_link_libraries(cache_server
engine-cache-server
$<TARGET_OBJECTS:utils>
mindspore
mindspore::glog
mindspore::protobuf
mindspore::grpc++
mindspore_gvar
${PYTHON_LIBRARIES}
${SECUREC_LIBRARY}
pthread)
add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES} mindspore::glog)
add_dependencies(engine-cache-server generated_engine_files)
else ()
ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PRTO_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
endif ()
add_dependencies(engine-cache-client generated_engine_files)

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unistd.h>
#include <iostream>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include "minddata/dataset/engine/cache/cache_admin_arg.h"
namespace ds = mindspore::dataset;
int main(int argc, char **argv) {
ds::Status rc;
ds::CacheAdminArgHandler args;
std::stringstream arg_stream;
#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
google::InitGoogleLogging(argv[0]);
#endif
std::string warningMsg;
warningMsg.reserve(512);
warningMsg += "WARNING:\n";
warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research";
warningMsg += " purposes at this time.\n";
warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n";
// A warning message until the code is mature enough.
std::cerr << warningMsg << std::endl;
if (argc == 1) {
args.Help();
return 0;
}
// ingest all the args into a string stream for parsing
for (int i = 1; i < argc; ++i) {
arg_stream << " " << std::string(argv[i]);
}
// Parse the args
rc = args.ParseArgStream(&arg_stream);
if (!rc.IsOk()) {
std::cerr << rc.ToString() << std::endl;
return 1;
}
// Execute the command
rc = args.RunCommand();
if (!rc.IsOk()) {
std::cerr << rc.ToString() << std::endl;
return 1;
}
return 0;
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,105 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <sstream>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/cache/cache_client.h"
namespace mindspore {
namespace dataset {
class CacheAdminArgHandler {
public:
static constexpr int32_t kDefaultPort = 50052;
static constexpr int32_t kDefaultNumWorkers = 32;
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
static constexpr int32_t kDefaultLogLevel = 1;
static const char kDefaultHost[];
static const char kServerBinary[];
static const char kDefaultSpillDir[];
// These are the actual command types to execute
enum class CommandId : int16_t {
kCmdHelp = 0,
kCmdStart = 1,
kCmdStop = 2,
kCmdGenerateSession = 3,
kCmdDestroySession = 4,
kCmdUnknown = 32767
};
CacheAdminArgHandler();
~CacheAdminArgHandler() = default;
Status ParseArgStream(std::stringstream *arg_stream);
Status RunCommand();
void Help();
private:
// These are the supported argument string integer mappings
enum class ArgValue : int16_t {
kArgUnknown = 0, // Must be at position 0. invalid map lookups in arg_map_ default to value 0
kArgStart = 1,
kArgStop = 2,
kArgHost = 3,
kArgPort = 4,
kArgHelp = 5,
kArgGenerateSession = 6,
kArgDestroySession = 7,
kArgSpillDir = 8,
kArgNumWorkers = 9,
kArgSharedMemorySize = 10,
kArgLogLevel = 11,
kArgNumArgs = 12 // Must be the last position to provide a count
};
Status StartServer();
Status StopServer();
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
Status Validate();
CommandId command_id_;
int32_t port_;
int32_t num_workers_;
int32_t shm_mem_sz_;
int32_t log_level_;
session_id_type session_id_;
std::string hostname_;
std::string spill_dir_;
std::string trailing_args_;
std::map<std::string, ArgValue> arg_map_;
std::map<ArgValue, bool> used_args_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_

@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
: Arena::Arena(val_in_GB * 1024), port_(port), shmid_(-1) {}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {
#if CACHE_LOCAL_CLIENT
if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
shmdt(this->ptr_);
}
this->ptr_ = nullptr;
if (shmid_ != -1) {
shmctl(shmid_, IPC_RMID, nullptr);
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
}
#endif
}
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
#if CACHE_LOCAL_CLIENT
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
if (ba == nullptr) {
return Status(StatusCode::kOutOfMemory);
}
// Transfer the ownership of this pointer. Any future error in the processing we will have
// the destructor of *out to deal.
(*out).reset(ba);
// Generate the ftok using a combination of port.
int err;
auto shm_key = PortToFtok(port, &err);
if (shm_key == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
ba->shmid_ = shmget(shm_key, ba->size_in_bytes_, IPC_CREAT | IPC_EXCL | access_mode);
if (ba->shmid_) {
ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
} else {
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
}
uint64_t num_blks = ba->size_in_bytes_ / ARENA_BLK_SZ;
MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << ".";
ba->tr_.Insert(0, num_blks);
#endif
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
#include <memory>
#include <string>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
class CachedSharedMemoryArena : public Arena {
public:
~CachedSharedMemoryArena() override;
/// \brief Create an Arena in shared memory
/// \param[out] p_ba Pointer to a unique_ptr
/// \param shmkey Shared memory key
/// \param val_in_GB size of shared memory in gigabyte
/// \return Status object
static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
/// \brief This returns where we attach to the shared memory.
/// Some gRPC requests will ask for a shared memory block, and
/// we can't return the absolute address as this makes no sense
/// in the client. So instead we will return an address relative
/// to the base address of the shared memory where we attach to.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return this->ptr_; }
private:
int32_t port_;
int shmid_;
/// Private constructor. Not to be called directly.
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_

File diff suppressed because it is too large Load Diff

@ -23,9 +23,13 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "minddata/dataset/core/config_manager.h"
#ifdef ENABLE_CACHE
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#else
#include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"
#endif
#include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/lock.h" #include "minddata/dataset/util/lock.h"
namespace mindspore { namespace mindspore {
@ -35,18 +39,120 @@ namespace dataset {
/// rows, etc. /// rows, etc.
class CacheClient { class CacheClient {
public: public:
friend class CacheMergeOp;
/// \brief A builder to help creating a CacheClient object
class Builder {
public:
Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
hostname_ = "127.0.0.1";
port_ = 50052;
num_workers_ = cfg->num_parallel_workers();
prefetch_size_ = 20; // rows_per_buf is too small (1 by default).
}
/// Setter function to set the session id
/// \param session_id
/// \return Builder object itself.
Builder &SetSessionId(session_id_type session_id) {
session_id_ = session_id;
return *this;
}
/// Setter function to set the cache memory size
/// \param cache_mem_sz
/// \return Builder object itself
Builder &SetCacheMemSz(uint64_t cache_mem_sz) {
cache_mem_sz_ = cache_mem_sz;
return *this;
}
/// Setter function to spill attribute
/// \param spill
/// Builder object itself
Builder &SetSpill(bool spill) {
spill_ = spill;
return *this;
}
/// Setter function to set rpc hostname
/// \param host
/// \return Builder object itself
Builder &SetHostname(std::string host) {
hostname_ = std::move(host);
return *this;
}
/// Setter function to set tcpip port
/// \param port
/// \return Builder object itself.
Builder &SetPort(int32_t port) {
port_ = port;
return *this;
}
/// Setter function to set number of async rpc workers
/// \param num_workers
/// \return Builder object itself
Builder &SetNumWorkers(int32_t num_workers) {
num_workers_ = num_workers;
return *this;
}
/// Setter function to set prefetch amount for fetching rows from cache server
/// \param prefetch_sz
/// \return Builder object itself
Builder &SetPrefetchSize(int32_t prefetch_sz) {
prefetch_size_ = prefetch_sz;
return *this;
}
/// Getter functions
session_id_type getSessionId() const { return session_id_; }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }
Status SanityCheck() {
CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited");
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
return Status::OK();
}
Status Build(std::shared_ptr<CacheClient> *out) {
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_IF_NOT_OK(SanityCheck());
*out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_,
prefetch_size_);
return Status::OK();
}
private:
session_id_type session_id_;
uint64_t cache_mem_sz_;
bool spill_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t prefetch_size_;
};
/// \brief Constructor /// \brief Constructor
/// \param session_id A user assigned session id for the current pipeline /// \param session_id A user assigned session id for the current pipeline
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory /// \param spill Spill to disk if out of memory
CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill); CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
int32_t num_workers, int32_t prefetch_size);
/// \brief Destructor /// \brief Destructor
~CacheClient() = default; ~CacheClient() { (void)comm_->ServiceStop(); }
/// \brief Getter function for returning the current session id
/// \return session id
uint64_t session_id() const { return session_id_; }
/// \brief Send a TensorRow to the cache server /// \brief Send a TensorRow to the cache server
/// \param[in] row /// \param[in] row
@ -83,14 +189,7 @@ class CacheClient {
/// \brief Get the statistics from a cache. /// \brief Get the statistics from a cache.
/// \param[in/out] Pointer to a pre-allocated ServiceStat object /// \param[in/out] Pointer to a pre-allocated ServiceStat object
/// \return Status object /// \return Status object
struct ServiceStat { Status GetStat(CacheServiceStat *);
int64_t num_mem_cached;
int64_t num_disk_cached;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
};
Status GetStat(ServiceStat *);
/// \brief Cache the schema at the cache server /// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema /// \param map The unordered map of the schema
@ -122,18 +221,45 @@ class CacheClient {
/// \return Cookie /// \return Cookie
std::string cookie() const { return cookie_; } std::string cookie() const { return cookie_; }
/// \brief Send a request async to the server
/// \param rq BaseRequest
/// \return Status object
Status PushRequest(std::shared_ptr<BaseRequest> rq) const;
/// \brief If the remote server supports local bypass using shared memory
/// \return boolean value
bool SupportLocalClient() const { return local_bypass_; }
/// \brief Return the base memory address if we attach to any shared memory.
auto SharedMemoryBaseAddr() const { return comm_->SharedMemoryBaseAddr(); }
/// Getter functions
session_id_type session_id() const { return cinfo_.session_id(); }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }
private: private:
mutable RWLock mux_; mutable RWLock mux_;
uint64_t cache_mem_sz_; uint64_t cache_mem_sz_;
bool spill_; bool spill_;
// The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
// sharing of the cache. // sharing of the cache.
uint32_t session_id_; CacheClientInfo cinfo_;
uint32_t cache_crc_;
// The server_connection_id_ is the actual id we use for operations after the cache is built // The server_connection_id_ is the actual id we use for operations after the cache is built
connection_id_type server_connection_id_; connection_id_type server_connection_id_;
// Some magic cookie returned from the cache server. // Some magic cookie returned from the cache server.
std::string cookie_; std::string cookie_;
// Comm layer
bool local_bypass_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t prefetch_size_;
mutable std::shared_ptr<CacheClientGreeter> comm_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
/// \note This header file contains common header files and some inlines used by
/// both client and server side codes. Do not put code that is not common here.
/// There are client and server specific header files.
// On platform like Windows, we may support only tcp/ip clients
#if !defined(_WIN32) && !defined(_WIN64)
#define CACHE_LOCAL_CLIENT 1
#endif
#ifdef CACHE_LOCAL_CLIENT
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#else
typedef int key_t;
#endif
#ifdef ENABLE_CACHE
#include <grpcpp/grpcpp.h>
#endif
#include <string>
#ifdef ENABLE_CACHE
#include "proto/cache_grpc.grpc.pb.h"
#endif
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
namespace mindspore {
namespace dataset {
/// \brief CacheRow and BatchFetch requests will switch to use shared memory method (if supported
/// on the platform) when the amount of bytes sent is greater than the following number.
/// For too small amount, we won't get any benefit using shared memory method because we need
/// two rpc requests to use shared memory method.
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
/// inline in the protobuf. This also implies kLocalClientSupport is also true.
constexpr static uint32_t kDataIsInSharedMemory = 2;
/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
/// \param reply[in/out] pointer to pre-allocated protobuf object
inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
reply->set_rc(static_cast<google::int32>(rc.get_code()));
reply->set_msg(rc.ToString());
}
/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
/// \param port
/// \return unix socket url
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
/// \brief Generate a shared memory key using the tcp/ip port.
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
/// \note Caller must check the return value. -1 means ftok failed.
/// \param[in] port
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
/// \return key
inline key_t PortToFtok(int port, int *err) {
key_t shmkey = -1;
#ifdef CACHE_LOCAL_CLIENT
const std::string unix_path = PortToUnixSocketPath(port);
shmkey = ftok(unix_path.data(), 'a');
if (err != nullptr && shmkey == (key_t)-1) {
*err = errno;
}
#endif
return shmkey;
}
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_

@ -0,0 +1,151 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace mindspore {
namespace dataset {
/// A private function used by SerializeTensorRowHeader to serialize each column in a tensor
/// \note Not to be called by outside world
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
RETURN_UNEXPECTED_IF_NULL(out_off);
const Tensor *ts = ts_ptr.get();
auto shape_off = fbb->CreateVector(ts->shape().AsVector());
const auto ptr = ts->GetBuffer();
if (ptr == nullptr) {
RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
}
auto src = ts->type().value();
TensorType dest;
#define CASE(t) \
case DataType::t: \
dest = TensorType::TensorType_##t; \
break
// Map the type to fill in the flat buffer.
switch (src) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
default:
MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
RETURN_STATUS_UNEXPECTED("Unknown type");
}
#undef CASE
TensorMetaMsgBuilder ts_builder(*fbb);
ts_builder.add_dims(shape_off);
ts_builder.add_type(dest);
auto ts_off = ts_builder.Finish();
*out_off = ts_off;
return Status::OK();
}
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
RETURN_UNEXPECTED_IF_NULL(out_fbb);
auto fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
try {
fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
std::vector<int64_t> tensor_sz;
v.reserve(row.size());
tensor_sz.reserve(row.size());
// We will go through each column in the row.
for (const std::shared_ptr<Tensor> &ts_ptr : row) {
flatbuffers::Offset<TensorMetaMsg> ts_off;
RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off));
v.push_back(ts_off);
tensor_sz.push_back(ts_ptr->SizeInBytes());
}
auto column_off = fbb->CreateVector(v);
auto data_sz_off = fbb->CreateVector(tensor_sz);
TensorRowHeaderMsgBuilder row_builder(*fbb);
row_builder.add_column(column_off);
row_builder.add_data_sz(data_sz_off);
// Pass the row_id even if it may not be known.
row_builder.add_row_id(row.getId());
row_builder.add_size_of_this(-1); // fill in later after we call Finish.
auto out = row_builder.Finish();
fbb->Finish(out);
// Now go back to fill in size_of_this in the flat buffer.
auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
auto success = msg->mutate_size_of_this(fbb->GetSize());
if (!success) {
RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
}
(*out_fbb) = std::move(fbb);
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) {
RETURN_UNEXPECTED_IF_NULL(col_ts);
auto shape_in = col_ts->dims();
auto type_in = col_ts->type();
std::vector<dsize_t> v;
v.reserve(shape_in->size());
v.assign(shape_in->begin(), shape_in->end());
TensorShape shape(v);
DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t) \
case TensorType_##t: \
dest = DataType::Type::t; \
break
switch (type_in) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
}
#undef CASE
DataType type(dest);
std::shared_ptr<Tensor> ts;
RETURN_IF_NOT_OK(
Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
// Next we restore the real data which can be embedded or stored separately.
if (ts->SizeInBytes() != data.GetSize()) {
MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
<< "Dumping tensor\n"
<< *ts << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
*out = std::move(ts);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
/// This header contains some serialize and deserialize functions for tensor row using
/// Google Flatbuffer
#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/util/slice.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief Function to serialize TensorRow header used by CacheRowRequest
/// \param row TensorRow
/// \param fbb [in/out] fbb that contains the serialized data
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *fbb);
/// \brief A function used by BatchFetchRequest to deserialize a flat buffer back to a tensor row.
/// \param col_ts A serialized version of Tensor meta data
/// \param data Tensor data wrapped in a slice
/// \param out Tensor
/// \return Status object
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.dataset;
option cc_enable_arenas = true;
// The session_id and crc work together to uniquely identify this particular cache and allow
// sharing of the cache.
message CacheClientInfo {
uint32 session_id = 1;
uint32 crc = 2;
}
message CacheRequest {
// Type of rpc request
int32 type = 1;
// Extra optional flag used by individual request if needed
uint32 flag = 2;
oneof connect_info {
// The server_connection_id is the actual id we use for operations after the cache is built
int64 connection_id = 3;
// But some request like CreateCache we have to use the session id and crc to connect to the server.
CacheClientInfo connection_info = 4;
}
// Everything else is just vector of buffers
repeated bytes buf_data = 5;
}
message CacheReply {
int32 rc = 1;
string msg = 2;
// Extra optional flag used by individual request if needed
uint32 flag = 3;
// What the server send back is a plain buffer
bytes result = 4;
}
service CacheServerGreeter {
rpc CacheServerRequest (CacheRequest) returns (CacheReply) {}
}

@ -0,0 +1,161 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#include <chrono>
namespace mindspore {
namespace dataset {
Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag) {
// If there is anything extra we need to do before we send.
RETURN_IF_NOT_OK(tag->base_rq_->Prepare());
// One minute timeout
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
tag->ctx_.set_deadline(deadline);
tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq);
tag->rpc_->StartCall();
// Last step is we release the ownership and transfer it to the completion queue.
// The memory will be released by WorkerEntry or by the destructor when we drain the queue
auto ccReqTag = tag.release();
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_,
ccReqTag); // inject this object into the completion queue
return Status::OK();
}
CacheClientGreeter::~CacheClientGreeter() {
(void)ServiceStop();
// Detach from shared memory if any
if (shmat_addr_ != nullptr) {
shmdt(shmat_addr_);
shmat_addr_ = nullptr;
}
}
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers)
: num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) {
grpc::ChannelArguments args;
// We need to bump up the message size to unlimited. The default receiving
// message limit is 4MB which is not big enough.
args.SetMaxReceiveMessageSize(-1);
#if CACHE_LOCAL_CLIENT
// Try connect locally to the unix_socket first as the first preference
// Need to resolve hostname to ip address rather than to do a string compare
if (hostname == "127.0.0.1") {
std::string target = "unix://" + PortToUnixSocketPath(port);
channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
} else {
#endif
std::string target = hostname + ":" + std::to_string(port);
channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
#if CACHE_LOCAL_CLIENT
}
#endif
stub_ = CacheServerGreeter::NewStub(channel_);
}
Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) {
*local_bypass = false;
#if CACHE_LOCAL_CLIENT
int err;
shm_key_ = PortToFtok(port, &err);
if (shm_key_ == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
// Attach to the shared memory
shm_id_ = shmget(shm_key_, 0, 0);
if (shm_id_ == -1) {
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
}
shmat_addr_ = shmat(shm_id_, nullptr, 0);
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
*local_bypass = true;
#endif
return Status::OK();
}
Status CacheClientGreeter::DoServiceStart() {
RETURN_IF_NOT_OK(vg_.ServiceStart());
RETURN_IF_NOT_OK(DispatchWorkers(num_workers_));
return Status::OK();
}
Status CacheClientGreeter::DoServiceStop() {
// Shutdown the queue. We don't accept any more new incomers.
cq_.Shutdown();
// Shutdown the TaskGroup.
vg_.interrupt_all();
vg_.join_all(Task::WaitFlag::kNonBlocking);
// Drain the queue
bool success;
void *tag;
while (cq_.Next(&tag, &success)) {
auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
delete r;
}
return Status::OK();
}
Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq));
return tag->MakeCall(stub_.get(), &cq_, std::move(tag));
}
Status CacheClientGreeter::WorkerEntry() {
TaskManager::FindMe()->Post();
do {
bool success;
void *tag;
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
// Set a timeout for one second. Check for interrupt if we need to do early exit.
auto r = cq_.AsyncNext(&tag, &success, deadline);
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
auto rq = reinterpret_cast<CacheClientRequestTag *>(tag);
if (success) {
auto &rc = rq->rc_;
if (!rc.ok()) {
auto error_code = rq->rc_.error_code();
std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
}
// Notify the waiting thread.
rq->Notify();
}
// We can now free the memory
delete rq;
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED();
} else {
// Queue is drained.
break;
}
} while (true);
return Status::OK();
}
Status CacheClientGreeter::DispatchWorkers(int32_t num_workers) {
auto f = std::bind(&CacheClientGreeter::WorkerEntry, this);
for (auto i = 0; i < num_workers; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async reply", f));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,102 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
/// \brief A client view of gRPC request
/// Like the class CacheServerRequest, this is used as a tag to inject into the gRPC
/// completion queue. The thread that makes the rpc request will wait on a wait post
/// area for the reply to come back. Since this tag will be deleted from memory and
/// we thus we need to work on a shared pointer of the BaseRequest such that its
/// use count is at least two. Otherwise either thread will be referencing stale memory.
/// \see CacheServerRequest
class CacheClientRequestTag {
public:
friend class CacheClientGreeter;
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {}
~CacheClientRequestTag() = default;
/// \brief Make a RPC call
/// \param stub from CacheClientGreeter
/// \param cq from CacheClientGreeter
/// \return Status object
static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag);
/// \brief Notify the client that a result has come back from the server
void Notify() { base_rq_->wp_.Set(); }
private:
std::shared_ptr<BaseRequest> base_rq_;
grpc::Status rc_;
grpc::ClientContext ctx_;
std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_;
};
/// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC
/// \see BaseRequest
class CacheClientGreeter : public Service {
friend class CacheClient;
public:
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers);
~CacheClientGreeter();
/// Override base Service class
Status DoServiceStart() override;
Status DoServiceStop() override;
/// \brief Send the request to the server
/// \return Status object
Status HandleRequest(std::shared_ptr<BaseRequest> rq);
/// \brief A handful of threads will be handling async reply from the server
/// \return
Status WorkerEntry();
/// \brief Kick off threads to receive reply from the server
Status DispatchWorkers(int32_t num_workers);
/// \brief Attach to shared memory for local client
/// \note Called after we have established a connection.
/// \return Status object.
Status AttachToSharedMemory(int32_t port, bool *local_bypass);
/// \brief This returns where we attach to the shared memory.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
private:
std::shared_ptr<grpc::Channel> channel_;
std::unique_ptr<CacheServerGreeter::Stub> stub_;
grpc::CompletionQueue cq_;
TaskGroup vg_;
int32_t num_workers_;
key_t shm_key_;
int32_t shm_id_;
void *shmat_addr_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_

@ -0,0 +1,203 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include <limits>
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) {
// Setup a path for unix socket.
unix_socket_ = PortToUnixSocketPath(port);
// We can't generate the ftok key yet until the unix_socket_ is created
}
void CacheServerGreeterImpl::Shutdown() {
if (server_) {
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
server_->Shutdown(deadline);
}
// Always shutdown the completion queue after the server.
if (cq_) {
cq_->Shutdown();
// We need to drain the queue. All the tag is coming from
// the Services pool which will be shutdown as well. So we
// ignore the tag.
void *tag;
bool success;
while (cq_->Next(&tag, &success)) {
}
}
}
CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); }
Status CacheServerGreeterImpl::IpcResourceCleanup() {
#if CACHE_LOCAL_CLIENT
int err;
auto shm_key = PortToFtok(port_, &err);
// We are expecting the unix path doesn't exist.
if (shm_key == (key_t)-1) {
return Status::OK();
}
// Attach to the shared memory
auto shm_id = shmget(shm_key, 0, 0);
if (shm_id == -1) {
return Status::OK();
}
struct shmid_ds ds {};
auto inx = shmctl(shm_id, IPC_STAT, &ds);
if (inx == -1) {
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
if (ds.shm_nattch == 0) {
// Stale shared memory from last time.
// Remove both the memory and the socket path
inx = shmctl(shm_id, IPC_RMID, nullptr);
if (inx == -1) {
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id);
errMsg += ". Errno :" + std::to_string(errno);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
Path p(unix_socket_);
(void)p.Remove();
} else {
// Server is already up.
MS_LOG(ERROR) << "Cache server is already up and running";
// We return a duplicate error. The main() will intercept
// and output a proper message
return Status(StatusCode::kDuplicateKey);
}
#endif
return Status::OK();
}
Status CacheServerGreeterImpl::Run() {
// To listen on all interfaces, use 0.0.0.0
// Use 127.0.0.1 if just locally on the same machine.
std::string host("0.0.0.0"); // listen on all interfaces.
std::string server_address = host + ":" + std::to_string(port_);
grpc::ServerBuilder builder;
// Default message size for gRPC is 4MB. Increase it to 2g-1
builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
int port_tcpip = 0;
#if CACHE_LOCAL_CLIENT
int port_local = 0;
// Check if we need to do clean up on the shared memory if the server
// came down unexpectedly like SEGV
RETURN_IF_NOT_OK(IpcResourceCleanup());
// We also optimize on local clients on the same machine using unix socket
builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local);
#endif
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip);
builder.RegisterService(&svc_);
cq_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart();
if (server_) {
MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
MS_LOG(INFO) << "Creation of local socket and shared memory successful";
#endif
} else {
std::string errMsg = "Fail to start server. ";
if (port_tcpip != port_) {
errMsg += "Unable to bind to tcpip port " + std::to_string(port_) + ".";
}
#if CACHE_LOCAL_CLIENT
if (port_local == 0) {
errMsg += " Unable to create unix socket " + unix_socket_ + ".";
}
#endif
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status CacheServerGreeterImpl::HandleRequest(int32_t worker_id) {
bool success;
void *tag;
// We loop through the grpc queue. Each connection if successful
// will come back with our own tag which is an instance of CacheServerRequest
// and we simply call its functor. But first we need to create these instances
// and inject them into the grpc queue.
CacheServerRequest *p;
// Get a free tag from my free list.
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(worker_id, &p));
RETURN_IF_NOT_OK((*p)(&svc_, cq_.get()));
do {
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
// Set a timeout for one second. Check for interrupt if we need to do early exit.
auto r = cq_->AsyncNext(&tag, &success, deadline);
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
if (success) {
auto rq = static_cast<CacheServerRequest *>(tag);
RETURN_IF_NOT_OK((*rq)(&svc_, cq_.get()));
}
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED();
} else {
// Queue is drained.
break;
}
} while (true);
return Status::OK();
}
Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq) {
auto myQID = getQid();
if (st_ == STATE::CREATE) {
st_ = STATE::PROCESS;
svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this);
} else if (st_ == STATE::PROCESS) {
// Get a new tag and handle the next request before we serve the current request.
// The tag will be recycled when its state is changed to FINISH
CacheServerRequest *next_rq;
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq));
RETURN_IF_NOT_OK((*next_rq)(svc, cq));
// Now we continue with the current request.
// First thing we need to extract the type from the incoming request.
// When this object was first created (i.e. STATE::CREATE), we set the type to UNKNOWN.
type_ = static_cast<RequestType>(rq_.type());
// Now we pass the address of this instance to CacheServer's main loop.
MS_LOG(DEBUG) << "Handle request " << *this;
auto &cs = CacheServer::GetInstance();
RETURN_IF_NOT_OK(cs.PushRequest(myQID, this));
} else if (st_ == STATE::FINISH) {
MS_LOG(DEBUG) << *this << " Finished.";
// Return back to the free list.
RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(this));
}
return Status::OK();
}
void CacheServerRequest::Print(std::ostream &out) const {
if (rq_.has_connection_info()) {
out << "Session Id: " << rq_.connection_info().session_id() << " CRC: " << rq_.connection_info().crc();
} else {
out << "Connection Id: " << rq_.connection_id();
}
out << " ";
BaseRequest::Print(out);
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
/// \brief Server side view of BaseRequest. Incoming request are in the form of protobuf objects
/// and this class is used to translate from protobuf to structures understood by CacheService class.
/// \see CacheService
class CacheServerRequest : public BaseRequest {
public:
friend class CacheServer;
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
explicit CacheServerRequest(int32_t queue_id)
: BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),
qid_(queue_id),
st_(STATE::CREATE),
responder_(&ctx_) {}
~CacheServerRequest() = default;
/// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this
/// functor will translate each protobuf into some form understood by by CacheService class.
/// \param svc Async service
/// \param cq Completion queue
/// \return Status object
Status operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq);
/// \brief Override the base class Print method
/// \param out
void Print(std::ostream &out) const override;
/// \brief Getter of the queue id
/// \return The queue where the request should go to
int32_t getQid() const { return qid_; }
private:
int32_t qid_;
Status rc_;
STATE st_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<CacheReply> responder_;
};
/// \brief Implementation of CacheServerGreeter
/// \note It is an async server
/// \see cache_grpc.proto
class CacheServerGreeterImpl final {
friend class CacheServer;
public:
explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
virtual ~CacheServerGreeterImpl();
/// \brief Brings up gRPC server
/// \return none
Status Run();
/// \brief Entry function to handle cache server request
Status HandleRequest(int32_t worker_id);
/// Return the shared memory pool.
/// \return Return the shared memory pool
CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }
void Shutdown();
Status IpcResourceCleanup();
private:
int32_t port_;
size_t shm_pool_sz_in_gb_;
std::string unix_socket_;
CacheServerGreeter::AsyncService svc_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_

@ -0,0 +1,121 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_server.h"
#include <sys/types.h>
#include <unistd.h>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include <cstdlib>
namespace ds = mindspore::dataset;
int main(int argc, char **argv) {
ds::Status rc;
ds::CacheServer::Builder builder;
// This executable is not to be called directly, and should be invoked by cache_admin executable.
if (argc != 7) {
rc = ds::Status(ds::StatusCode::kSyntaxError);
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}
builder.SetRootDirectory(argv[1])
.SetNumWorkers(strtol(argv[2], nullptr, 10))
.SetPort(strtol(argv[3], nullptr, 10))
.SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10));
#ifdef USE_GLOG
FLAGS_minloglevel = strtol(argv[5], nullptr, 10);
#endif
auto daemonize_string = argv[6];
bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 ||
strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0;
// We always change directory to / on unix rather than using the directory where the cache_server
// is called. This is a standard procedure for daemonize a process on unix.
if (chdir("/") == -1) {
std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
std::cerr << errMsg << std::endl;
return -1;
}
// Simple check of the parameters before we move on.
rc = builder.SanityCheck();
if (rc.IsError()) {
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}
#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
google::InitGoogleLogging(argv[0]);
#endif
if (daemonize) {
// fork the child process to become the daemon
pid_t pid = fork();
// failed to fork
if (pid < 0) {
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
std::cerr << err_msg << std::endl;
return errno;
} else if (pid > 0) {
// Parent
std::cerr << "cache server daemon process has been created as process id: " << pid
<< "\nCheck log file for any start up error" << std::endl;
signal(SIGCHLD, SIG_IGN); // ignore sig child signal.
return 0;
} else {
// Child process will continue from here if daemonize and parent has already exited.
// If we are running in the foreground, none of the code in block below will be run.
pid_t sid;
umask(0);
sid = setsid();
if (sid < 0) {
MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno);
return errno;
}
close(0);
close(1);
close(2);
}
}
// Dump the summary
MS_LOG(INFO) << builder << std::endl;
rc = builder.Build();
if (rc.IsOk()) {
ds::CacheServer &cs = ds::CacheServer::GetInstance();
// Kick off the threads. Loop forever and never return unless error.
rc = cs.Run();
if (rc.get_code() == ds::StatusCode::kDuplicateKey) {
std::string errMsg = "Server is already started";
MS_LOG(ERROR) << errMsg;
std::cerr << errMsg << std::endl;
return 0;
}
}
if (rc.IsError()) {
MS_LOG(ERROR) << rc.ToString();
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}
return 0;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -24,8 +24,11 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <map> #include <map>
#include <set>
#include "minddata/dataset/engine/cache/cache_service.h" #include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h" #include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/cache_pool.h" #include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/util/lock.h" #include "minddata/dataset/util/lock.h"
@ -37,43 +40,131 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class BaseRequest;
/// \brief A server which provides CacheService services. /// \brief A server which provides CacheService services.
class CacheServer : public Service { class CacheServer : public Service {
public: public:
friend class Services; friend class Services;
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
class Builder {
public:
Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {}
/// \brief Getter functions
const std::string &getTop() const { return top_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPort() const { return port_; }
int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; }
Builder &SetRootDirectory(std::string root) {
top_ = std::move(root);
return *this;
}
Builder &SetNumWorkers(int32_t n) {
num_workers_ = n;
return *this;
}
Builder &SetPort(int32_t p) {
port_ = p;
return *this;
}
Builder &SetSharedMemorySizeInGB(int32_t sz) {
shared_memory_sz_in_gb_ = sz;
return *this;
}
Status SanityCheck();
void Print(std::ostream &out) const {
out << "Summary of the cache server configuration\n"
<< "Spill directory: " << getTop() << "\n"
<< "Number of parallel workers: " << getNumWorkers() << "\n"
<< "Tcp/ip port: " << getPort() << "\n"
<< "Shared memory size (in GB): " << getSharedMemorySzInGb();
}
friend std::ostream &operator<<(std::ostream &out, const Builder &bld) {
bld.Print(out);
return out;
}
Status Build() {
RETURN_IF_NOT_OK(SanityCheck());
// We need to bring up the Task Manager by bringing up the Services singleton.
RETURN_IF_NOT_OK(Services::CreateInstance());
RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_));
return Status::OK();
}
private:
std::string top_;
int32_t num_workers_;
int32_t port_;
int32_t shared_memory_sz_in_gb_;
};
CacheServer(const CacheServer &) = delete; CacheServer(const CacheServer &) = delete;
CacheServer &operator=(const CacheServer &) = delete; CacheServer &operator=(const CacheServer &) = delete;
CacheServer(CacheServer &&) = delete; CacheServer(CacheServer &&) = delete;
CacheServer &operator=(CacheServer &) = delete; CacheServer &operator=(CacheServer &) = delete;
static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
Status DoServiceStart() override; Status DoServiceStart() override;
Status DoServiceStop() override; Status DoServiceStop() override;
~CacheServer() { (void)ServiceStop(); } ~CacheServer() { (void)ServiceStop(); }
static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port,
int32_t shared_memory_sz) {
std::call_once(init_instance_flag_, [&]() -> Status {
auto &svcManager = Services::GetInstance();
RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz));
return Status::OK();
});
return Status::OK();
}
static CacheServer &GetInstance() { return *instance_; }
/// \brief For the current demonstration, a cache client contacts cache server using a Queue. /// \brief For the current demonstration, a cache client contacts cache server using a Queue.
/// \param rq /// \param rq
/// \return Status object /// \return Status object
Status PushRequest(BaseRequest *rq) { Status PushRequest(int32_t queue_id, CacheServerRequest *rq) {
RETURN_UNEXPECTED_IF_NULL(rq); RETURN_UNEXPECTED_IF_NULL(rq);
RETURN_IF_NOT_OK(cache_q_->Add(rq)); RETURN_IF_NOT_OK(cache_q_->operator[](queue_id)->Add(rq));
return Status::OK(); return Status::OK();
} }
/// \\brief Kick off server threads. Never return unless error out.
Status Run();
/// \brief Get a free tag
/// \param q[in] pointer to a pointer to a CacheServerRequest
/// \return Status object
static Status GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q);
/// \brief Return a tag to the free list
/// \param p[in] pointer to already finished CacheServerRequest tag
/// \return Status object
static Status ReturnRequestTag(CacheServerRequest *p);
private: private:
static std::once_flag init_instance_flag_;
static CacheServer *instance_;
mutable RWLock rwLock_; mutable RWLock rwLock_;
std::string top_; std::string top_;
cache_index all_caches_; cache_index all_caches_;
std::shared_ptr<Queue<BaseRequest *>> cache_q_; std::set<session_id_type> history_sessions_;
std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_;
std::shared_ptr<CacheServerGreeterImpl> comm_layer_;
std::shared_ptr<MemoryPool> mp_;
TaskGroup vg_; TaskGroup vg_;
int32_t num_workers_; int32_t num_workers_;
int32_t port_;
int32_t shared_memory_sz_in_gb_;
std::atomic<bool> global_shutdown_;
/// \brief Constructor /// \brief Constructor
/// \param spill_path Top directory for spilling buffers to. /// \param spill_path Top directory for spilling buffers to.
/// \param num_workers Number of threads for handling requests. /// \param num_workers Number of threads for handling requests.
explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3); explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb);
/// \brief Locate a cache service from connection id. /// \brief Locate a cache service from connection id.
/// \return Pointer to cache service. Null if not found /// \return Pointer to cache service. Null if not found
@ -82,16 +173,65 @@ class CacheServer : public Service {
/// \brief Create a cache service. We allow multiple clients to create the same cache service. /// \brief Create a cache service. We allow multiple clients to create the same cache service.
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
/// a special unique cookie. /// a special unique cookie.
/// \param[in] connection_id This is from a Cache client.
/// \param[in] cache_mem_sz
/// \param[in] flag
/// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator
/// \return Status object /// \return Status object
Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag, Status CreateService(CacheRequest *rq, CacheReply *reply);
std::string *out_cookie);
/// \brief Destroy a cache service
/// \param cs
/// \param rq
/// \return
Status DestroyCache(CacheService *cs, CacheRequest *rq);
Status PurgeCache(CacheService *cs);
/// \brief Entry point for all internal server threads.
Status ServerRequest(int32_t worker_id);
/// \brief Entry point for all grpc threads.
/// \return
Status RpcRequest(int32_t worker_id);
Status DestroySession(CacheRequest *rq);
/// \brief Create a connection id from a session id and a crc
/// \param session_id
/// \param crc
/// \return connection id
connection_id_type GetConnectionID(session_id_type session_id, uint32_t crc) const;
/// \brief Extract the session id from a connection id
/// \param connection_id
/// \return session id
session_id_type GetSessionID(connection_id_type connection_id) const;
/// \brief Generate a session ID for the client
/// \return Session ID
session_id_type GenerateSessionID() const;
/// \brief Handle kAllocateSharedBlock request
/// \param rq CacheRequest
/// \param reply CacheReply
/// \return Status object
Status AllocateSharedMemory(CacheRequest *rq, CacheReply *reply);
/// \brief Entry point for all server threads. /// \brief Handle kFreeSharedBlock request
Status ServerRequest(); /// \param rq
/// \return Status object
Status FreeSharedMemory(CacheRequest *rq);
/// \brief Handle kFastCacheRow request
/// \return Status object
Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply);
/// \brief Internal function to do row batch fetch
/// \param cs CacheService
/// \param rq Request
/// \param reply Reply
/// \return
Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply);
/// \brief A proper shutdown of the server
/// \return Status object
Status GlobalShutdown();
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save