!7933 Cache performance updates

Merge pull request !7933 from Jamie/CacheOp_dev
pull/7933/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit cacebd1211

@ -1,3 +1,4 @@
add_subdirectory(perf EXCLUDE_FROM_ALL)
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
@ -5,6 +6,18 @@ ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engin
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
# Try to find numa header file and its library
find_file(NUMA_HDR NAMES "numa.h")
if (EXISTS ${NUMA_HDR})
ADD_DEFINITIONS(-DNUMA_ENABLED)
MESSAGE("Numa package found")
endif ()
if (${CMAKE_SYSTEM_NAME} MATCHES "Linux")
ADD_DEFINITIONS(-DCACHE_LOCAL_CLIENT)
endif ()
add_library(engine-cache-client OBJECT
cache_client.cc
cache_fbb.cc
@ -20,8 +33,13 @@ if (ENABLE_CACHE)
${CACHE_GRPC_SRCS}
cache_grpc_server.cc
cache_arena.cc
cache_hw.cc
cache_numa.cc
cache_pool.cc
cache_service.cc
cache_server.cc)
cache_server.cc
storage_manager.cc
storage_container.cc)
add_executable(cache_server cache_main.cc)
target_link_libraries(cache_server
@ -39,6 +57,10 @@ if (ENABLE_CACHE)
target_link_libraries(cache_server mindspore::glog)
endif ()
if (EXISTS ${NUMA_HDR})
target_link_libraries(cache_server numa)
endif ()
add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES})
@ -49,7 +71,7 @@ if (ENABLE_CACHE)
add_dependencies(engine-cache-server generated_engine_files)
else ()
ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PRTO_HDRS cache_grpc.proto)
ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PROTO_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
endif ()

@ -18,6 +18,7 @@
#include <sys/stat.h>
#include <sys/wait.h>
#include <unistd.h>
#include <algorithm>
#include <cerrno>
#include <iomanip>
#include <iostream>
@ -31,7 +32,9 @@
namespace mindspore {
namespace dataset {
const int32_t CacheAdminArgHandler::kDefaultNumWorkers = std::thread::hardware_concurrency() > 2
? std::thread::hardware_concurrency() / 2
: 1;
const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";
@ -304,8 +307,10 @@ Status CacheAdminArgHandler::Validate() {
}
// Additional checks here
if (num_workers_ < 1 || num_workers_ > 100)
return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100.");
auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), 100);
if (num_workers_ < 1 || num_workers_ > max_num_workers)
return Status(StatusCode::kSyntaxError,
"Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + ".");
if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1)
return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1");
@ -354,13 +359,15 @@ Status CacheAdminArgHandler::RunCommand() {
std::vector<SessionCacheInfo> session_info = rq->GetSessionCacheInfo();
if (!session_info.empty()) {
std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached"
<< std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl;
<< std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::setw(10) << "Numa hit"
<< std::endl;
for (auto curr_session : session_info) {
std::string cache_id;
std::string stat_mem_cached;
std::string stat_disk_cached;
std::string stat_avg_cached;
int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF);
std::string stat_numa_hit;
uint32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF);
cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc);
stat_mem_cached =
(curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached);
@ -368,10 +375,12 @@ Status CacheAdminArgHandler::RunCommand() {
(curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached);
stat_avg_cached =
(curr_session.stats.avg_cache_sz == 0) ? "n/a" : std::to_string(curr_session.stats.avg_cache_sz);
stat_numa_hit =
(curr_session.stats.num_numa_hit == 0) ? "n/a" : std::to_string(curr_session.stats.num_numa_hit);
std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12)
<< stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached
<< std::endl;
<< std::setw(10) << stat_numa_hit << std::endl;
}
} else {
std::cout << "No active sessions." << std::endl;

@ -21,6 +21,7 @@
#include <memory>
#include <string>
#include <sstream>
#include <thread>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/cache/cache_client.h"
@ -29,7 +30,7 @@ namespace dataset {
class CacheAdminArgHandler {
public:
static constexpr int32_t kDefaultNumWorkers = 32;
static const int32_t kDefaultNumWorkers;
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
static constexpr int32_t kDefaultLogLevel = 1;
static constexpr float kMemoryCapRatio = 0.8;

@ -17,7 +17,6 @@
#include <iomanip>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#include "minddata/dataset/util/bit.h"
@ -59,6 +58,7 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool
: server_connection_id_(0),
cache_mem_sz_(cache_mem_sz),
spill_(spill),
client_id_(-1),
local_bypass_(false),
hostname_(std::move(hostname)),
port_(port),
@ -71,6 +71,22 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool
CacheClient::~CacheClient() {
cache_miss_keys_wp_.Set();
if (client_id_ != -1) {
try {
// Send a message to the server, saying I am done.
auto rq = std::make_shared<ConnectResetRequest>(server_connection_id_, client_id_);
Status rc = PushRequest(rq);
if (rc.IsOk()) {
rc = rq->Wait();
if (rc.IsOk()) {
MS_LOG(INFO) << "Disconnect from server successful";
}
}
} catch (const std::exception &e) {
// Can't do anything in destructor. So just log the error.
MS_LOG(ERROR) << e.what();
}
}
(void)comm_->ServiceStop();
}
@ -85,7 +101,7 @@ void CacheClient::Print(std::ostream &out) const {
}
Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
auto rq = std::make_shared<CacheRowRequest>(this);
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
@ -104,7 +120,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
for (auto i = 0; i < num_rows; ++i) {
TensorRow row;
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
arr[i] = std::make_shared<CacheRowRequest>(this);
RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row));
RETURN_IF_NOT_OK(PushRequest(arr[i]));
}
@ -118,7 +134,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient());
auto rq = std::make_shared<BatchFetchRequest>(this, row_id);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
int64_t mem_addr;
@ -167,7 +183,7 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock.
CacheServiceStat stat{};
RETURN_IF_NOT_OK(GetStat(&stat));
if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
if (stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase)) {
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
}
} else {
@ -183,18 +199,16 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
// Start the comm layer to receive reply
RETURN_IF_NOT_OK(comm_->ServiceStart());
// Initiate connection
auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag);
auto rq = std::make_shared<CreateCacheRequest>(this, cinfo_, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(PushRequest(rq));
Status rc = rq->Wait();
if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
std::string cookie;
rq->ParseResult(&server_connection_id_, &cookie);
if (rc.IsOk()) {
// The 1st guy creating the cache will get a cookie back.
// But this object may be shared among pipelines and we don't want
// overwrite it.
cookie_ = cookie;
}
bool success = (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey);
// If we get kDuplicateKey, it just means we aren't the first one to create the cache,
// and we will continue to parse the result.
if (rc.get_code() == StatusCode::kDuplicateKey) {
RETURN_IF_NOT_OK(rq->PostReply());
}
if (success) {
// Attach to shared memory for local client
RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
}

@ -47,6 +47,9 @@ namespace dataset {
class CacheClient {
public:
friend class CacheMergeOp;
friend class CreateCacheRequest;
friend class CacheRowRequest;
friend class BatchFetchRequest;
/// \brief A builder to help creating a CacheClient object
class Builder {
@ -115,7 +118,7 @@ class CacheClient {
session_id_type GetSessionId() const { return session_id_; }
uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
const std::string &GetHostname() const { return hostname_; }
int32_t GetPort() const { return port_; }
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
@ -256,8 +259,10 @@ class CacheClient {
CacheClientInfo cinfo_;
// The server_connection_id_ is the actual id we use for operations after the cache is built
connection_id_type server_connection_id_;
// Some magic cookie returned from the cache server.
// Some magic cookie/id returned from the cache server.
std::string cookie_;
int32_t client_id_;
std::vector<int32_t> cpu_list_;
// Comm layer
bool local_bypass_;
std::string hostname_;

@ -20,11 +20,6 @@
/// both client and server side codes. Do not put code that is not common here.
/// There are client and server specific header files.
// On platform like Windows, we may support only tcp/ip clients
#if !defined(_WIN32) && !defined(_WIN64)
#define CACHE_LOCAL_CLIENT 1
#endif
#ifdef ENABLE_CACHE
#include <grpcpp/grpcpp.h>
#endif
@ -50,6 +45,9 @@ constexpr static uint32_t kDataIsInSharedMemory = 2;
/// \brief Size of each message used in message queue.
constexpr static int32_t kSharedMessageSize = 2048;
/// \brief State of CacheService at the server.
enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
/// \param reply[in/out] pointer to pre-allocated protobuf object
@ -61,6 +59,22 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
/// \param port
/// \return unix socket url
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
/// \brief Round up to the next 4k
inline int64_t round_up_4K(int64_t sz) {
// Since 4096 is a power of 2, a simple way to round up is add 4095 and mask off all the
// bits of 4095
return static_cast<uint64_t>(sz + 4095) & ~static_cast<uint64_t>(4095);
}
/// Memory policy
enum CachePoolPolicy : int8_t { kOnNode, kPreferred, kLocal, kInterleave, kNone };
/// Misc typedef
using worker_id_t = int32_t;
using numa_id_t = int32_t;
using cpu_id_t = int32_t;
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_

@ -32,12 +32,13 @@ message CacheRequest {
uint32 flag = 2;
oneof connect_info {
// The server_connection_id is the actual id we use for operations after the cache is built
int64 connection_id = 3;
uint64 connection_id = 3;
// But some request like CreateCache we have to use the session id and crc to connect to the server.
CacheClientInfo connection_info = 4;
}
int32 client_id = 5;
// Everything else is just vector of buffers
repeated bytes buf_data = 5;
repeated bytes buf_data = 6;
}
message CacheReply {

@ -74,6 +74,9 @@ Status CacheServerGreeterImpl::Run() {
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
MS_LOG(INFO) << "Creation of local socket and shared memory successful";
auto cs = CacheServer::GetInstance().GetHWControl();
// This shared memory is a hot memory and we will interleave among all the numa nodes.
cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
#endif
} else {
std::string errMsg = "Fail to start server. ";
@ -127,8 +130,13 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
st_ = STATE::PROCESS;
svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this);
} else if (st_ == STATE::PROCESS) {
auto &cs = CacheServer::GetInstance();
// Get a new tag and handle the next request before we serve the current request.
// The tag will be recycled when its state is changed to FINISH
// The tag will be recycled when its state is changed to FINISH.
// The number of free list queues is the same as the number of grpc threads.
// Where we get the free list it doesn't matter (as long we return it back to the right queue).
// We can round robin, use the qid or even use the worker id. We will use the free list queue
// where the current request comes from.
CacheServerRequest *next_rq;
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq));
RETURN_IF_NOT_OK((*next_rq)(svc, cq));
@ -138,8 +146,24 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
type_ = static_cast<RequestType>(rq_.type());
// Now we pass the address of this instance to CacheServer's main loop.
MS_LOG(DEBUG) << "Handle request " << *this;
auto &cs = CacheServer::GetInstance();
RETURN_IF_NOT_OK(cs.PushRequest(myQID, this));
// We will distribute the request evenly (or randomly) over all the numa nodes.
// The exception is BatchFetch which we need to pre-process here.
if (type_ == BaseRequest::RequestType::kBatchFetchRows) {
rc_ = cs.BatchFetchRows(&rq_, &reply_);
if (!rc_.IsInterrupted()) {
Status2CacheReply(rc_, &reply_);
st_ = CacheServerRequest::STATE::FINISH;
responder_.Finish(reply_, grpc::Status::OK, this);
} else {
return rc_;
}
} else {
// When the number of grpc workers is the same as the server workers, we will use this queue id
// and push to the corresponding queue.
bool random = cs.GetNumWorkers() != cs.GetNumGrpcWorkers();
worker_id_t worker_id = random ? cs.GetRandomWorker() : myQID;
RETURN_IF_NOT_OK(cs.PushRequest(worker_id, this));
}
} else if (st_ == STATE::FINISH) {
MS_LOG(DEBUG) << *this << " Finished.";
// Return back to the free list.

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#include <atomic>
#include <memory>
#include <string>
#include <utility>
@ -34,6 +35,7 @@ namespace dataset {
class CacheServerRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
explicit CacheServerRequest(int32_t queue_id)
: BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),

@ -0,0 +1,220 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_hw.h"
#ifdef NUMA_ENABLED
#include <numa.h>
#endif
#include <sched.h>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <fstream>
#include <regex>
#include <thread>
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
CacheServerHW::CacheServerHW() {
num_cpus_ = std::thread::hardware_concurrency();
MS_LOG(DEBUG) << "Number of cpu(s) : " << num_cpus_;
#ifdef NUMA_ENABLED
if (numa_enabled()) {
MS_LOG(WARNING) << "Numa support enabled";
for (auto i = 0; i <= numa_max_node(); ++i) {
int64_t free_avail;
int64_t mem_avail = numa_node_size(i, &free_avail);
MS_LOG(INFO) << "Total physical/free RAM in bytes at node " << i << " : " << mem_avail << "/" << free_avail;
}
}
#endif
}
int64_t CacheServerHW::GetTotalSystemMemory() {
auto pages = sysconf(_SC_PHYS_PAGES);
auto page_size = sysconf(_SC_PAGE_SIZE);
auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size);
MS_LOG(INFO) << "Total physical RAM in bytes: " << total;
return total;
}
Status CacheServerHW::SetDefaultMemoryPolicy(CachePoolPolicy policy) {
#ifdef NUMA_ENABLED
if (numa_enabled()) {
// Set our default memory policy.
switch (policy) {
case kLocal:
numa_set_localalloc();
MS_LOG(DEBUG) << "Setting memory default policy to local node. Low level code may override the setting";
break;
case kInterleave:
numa_set_interleave_mask(numa_all_nodes_ptr);
MS_LOG(DEBUG) << "Numa affinity is turned off. Use interleave memory policy as default.";
break;
case kOnNode:
case kPreferred:
RETURN_STATUS_UNEXPECTED("Unsupported memory policy");
break;
case kNone:
default:
// No action taken.
break;
}
}
#endif
return Status::OK();
}
Status CacheServerHW::GetNumaNodeInfo() {
std::set<Path> numa_nodes_;
Path node(kSysNodePath);
auto it = Path::DirIterator::OpenDirectory(&node);
if (it == nullptr) {
MS_LOG(WARNING) << "Unable to open directory " << kSysNodePath << ". Skip scanning hardware info";
return Status::OK();
}
auto isdigit_string = [](const char *str) -> bool {
bool r = true;
for (auto i = 0; i < strlen(str); ++i) {
if (!std::isdigit(str[i])) {
r = false;
break;
}
}
return r;
};
// Look for name starts with 'node' and followed by digits.
const char kNodeName[] = "node";
while (it->hasNext()) {
auto p = it->next();
const std::string entry = p.Basename();
const char *name = entry.data();
if (strncmp(name, kNodeName, 4) == 0 && isdigit_string(name + strlen(kNodeName))) {
numa_nodes_.insert(p);
}
}
// There should be at least one. But if not found in any case, just move on the
// rest of the server start up.
if (numa_nodes_.empty()) {
MS_LOG(WARNING) << "No numa nodes ? Skip scanning hardware info";
return Status::OK();
}
// For each numa node, get a list of CPU that is associated with it.
const char kCpuList[] = "cpulist";
auto r = std::regex("[0-9]*-[0-9]*");
for (Path p : numa_nodes_) {
auto node_dir = p.Basename().data();
numa_id_t numa_node = strtol(node_dir + strlen(kNodeName), nullptr, 10);
Path f = p / kCpuList;
std::ifstream fs(f.toString());
CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + f.toString());
std::string cpu_string;
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
int32_t cpu_cnt = 0;
while (getline(fs, cpu_string)) {
// Now we parse the content of cpu_string.
std::sregex_iterator iter(cpu_string.begin(), cpu_string.end(), r);
std::sregex_iterator end;
while (iter != end) {
auto match = iter->str();
auto pos = match.find_first_of('-');
std::string min = match.substr(0, pos);
std::string max = match.substr(pos + 1);
cpu_id_t cpu_min = strtol(min.data(), nullptr, 10);
cpu_id_t cpu_max = strtol(max.data(), nullptr, 10);
MS_LOG(DEBUG) << "Numa node " << numa_node << " CPU(s) : " << cpu_min << "-" << cpu_max;
for (int i = cpu_min; i <= cpu_max; ++i) {
CPU_SET(i, &cpuset);
++cpu_cnt;
}
++iter;
}
}
CHECK_FAIL_RETURN_UNEXPECTED(!fs.bad(), "Fail to read file: " + f.toString());
fs.close();
// Remember which cpu is attached to this numa node.
numa_cpuset_.emplace(numa_node, cpuset);
numa_cpu_cnt_.emplace(numa_node, cpu_cnt);
}
MS_LOG(DEBUG) << "Number of numa nodes : " << numa_cpuset_.size();
return Status::OK();
}
Status CacheServerHW::SetAffinity(const Task &tk, numa_id_t numa_node) {
auto r = numa_cpuset_.find(numa_node);
if (r != numa_cpuset_.end()) {
auto err = pthread_setaffinity_np(tk.GetNativeHandle(), sizeof(r->second), &r->second);
if (err) {
std::string errMsg = "Unable to set affiity. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
} else {
RETURN_STATUS_UNEXPECTED("Numa node " + std::to_string(numa_node) + " not found");
}
return Status::OK();
}
std::vector<cpu_id_t> CacheServerHW::GetCpuList(numa_id_t numa_id) {
std::vector<cpu_id_t> v;
auto it = numa_cpuset_.find(numa_id);
if (it != numa_cpuset_.end()) {
auto &cpu_set = it->second;
for (auto i = 0; i < num_cpus_; ++i) {
if (CPU_ISSET(i, &cpu_set)) {
v.push_back(i);
}
}
}
return v;
}
numa_id_t CacheServerHW::GetMyNode() const {
numa_id_t node_id = 0;
auto cpu = sched_getcpu();
#ifdef NUMA_ENABLED
node_id = numa_node_of_cpu(cpu);
#else
bool found = false;
for (auto it : numa_cpuset_) {
cpu_set_t &cpu_set = it.second;
if (CPU_ISSET(cpu, &cpu_set)) {
node_id = it.first;
found = true;
break;
}
}
MS_LOG(DEBUG) << "cpu id " << cpu << " found : " << std::boolalpha << found;
#endif
return node_id;
}
void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) {
#ifdef NUMA_ENABLED
if (numa_enabled()) {
numa_interleave_memory(ptr, sz, numa_all_nodes_ptr);
}
#endif
}
bool CacheServerHW::numa_enabled() {
#ifdef NUMA_ENABLED
return (numa_available() != -1);
#else
return false;
#endif
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_
#ifdef NUMA_ENABLED
#include <numa.h>
#endif
#include <sched.h>
#include <stdlib.h>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/memory_pool.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task.h"
namespace mindspore {
namespace dataset {
class CacheServerHW {
public:
CacheServerHW();
~CacheServerHW() = default;
/// \brief Get Numa node info without using numa library
/// \return Status object
Status GetNumaNodeInfo();
/// \brief Set thread affinity
Status SetAffinity(const Task &tk, numa_id_t numa_node);
/// \brief Get total number of cpu(s)
int32_t GetCpuCount() const { return num_cpus_; }
/// \brief Get total number of numa nodes
int32_t GetNumaNodeCount() const { return numa_cpuset_.empty() ? 1 : numa_cpuset_.size(); }
/// \brief Get a list of cpu for a given numa node.
std::vector<cpu_id_t> GetCpuList(numa_id_t numa_id);
static bool numa_enabled();
/// \brief Return the numa the current thread is running on.
numa_id_t GetMyNode() const;
/// \brief Interleave a given memory block. Used by shared memory only.
static void InterleaveMemory(void *ptr, size_t sz);
/// \brief Set default memory policy.
static Status SetDefaultMemoryPolicy(CachePoolPolicy);
/// \brief This returns the size (in bytes) of the physical RAM on the machine.
/// \return the size (in bytes) of the physical RAM on the machine.
static int64_t GetTotalSystemMemory();
private:
constexpr static char kSysNodePath[] = "/sys/devices/system/node";
int32_t num_cpus_;
std::map<numa_id_t, cpu_set_t> numa_cpuset_;
std::map<numa_id_t, int32_t> numa_cpu_cnt_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_

@ -54,6 +54,8 @@ ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::
#endif
try {
rq->set_type(static_cast<int16_t>(type));
rq->set_client_id(-1);
rq->set_flag(0);
grpc::ChannelArguments args;
grpc::ClientContext ctx;
grpc::CompletionQueue cq;

@ -0,0 +1,224 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <iterator>
#include <limits>
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
NumaMemoryPool::NumaMemoryPool(std::shared_ptr<CacheServerHW> hw, float memory_cap_ratio)
: hw_(std::move(hw)), memory_cap_ratio_(memory_cap_ratio) {
int64_t total_avail = 0;
// We will create a number of small Arenas to spread out the server threads so it
// will be less contention. If we link with the numa library, i.e. if
// NUMA_ENABLED is defined, we will make use of the low level numa library such that
// each Arena solely comes from one particular socket.
// The total number of Arenas will be controlled under the number of cpus.
auto num_cpus = hw_->GetCpuCount();
memory_segments_.reserve(num_cpus);
arena_list_.reserve(num_cpus);
mux_ = std::make_unique<std::mutex[]>(num_cpus);
auto num_memory_nodes = num_cpus;
int64_t max_avail = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
int64_t arena_sz = max_avail / num_memory_nodes;
// If arena_sz is too small, lower the number of Arenas.
if (arena_sz < std::numeric_limits<int32_t>::max()) {
arena_sz = round_up_4K(std::numeric_limits<int32_t>::max());
num_memory_nodes = max_avail / arena_sz;
if (num_memory_nodes == 0) {
num_memory_nodes = 1;
arena_sz = max_avail;
}
}
MS_LOG(INFO) << "Creating " << num_memory_nodes << " number of arena. Each one of size " << arena_sz;
#ifdef NUMA_ENABLED
if (numa_available() != -1) {
auto num_numa_nodes = hw_->GetNumaNodeCount();
numa_id_t node_id = 0;
for (auto i = 0; i < num_memory_nodes; ++i) {
auto success = CreateMultipleArenas(arena_sz, node_id++ % num_numa_nodes, 1);
total_avail += success * arena_sz;
}
} else {
auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes);
total_avail += success * arena_sz;
}
#else
auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes);
total_avail += success * arena_sz;
#endif
memory_cap_ = total_avail;
MS_LOG(WARNING) << "Memory pool created. Total available memory " << memory_cap_ << " spread in " << nodes_.size()
<< " arenas";
int32_t slot = 0;
// Set up a map for future easy access.
for (auto node_id : nodes_) {
numa_map_[node_id].push_back(slot);
++slot;
}
}
int32_t NumaMemoryPool::CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count) {
int32_t success = 0;
for (auto i = 0; i < repeat_count; ++i) {
#ifdef NUMA_ENABLED
void *ptr = numa_alloc_onnode(segment_sz, node_id);
#else
void *ptr = malloc(segment_sz);
#endif
if (ptr != nullptr) {
memory_segments_.emplace_back(ptr, segment_sz);
arena_list_.push_back(std::make_unique<ArenaImpl>(ptr, segment_sz));
nodes_.push_back(node_id);
++success;
} else {
// Skip the rest.
break;
}
}
MS_LOG(DEBUG) << "Allocate " << success << " arenas from node " << node_id;
return success;
}
NumaMemoryPool::~NumaMemoryPool() {
if (!memory_segments_.empty()) {
for (auto &s : memory_segments_) {
#ifdef NUMA_ENABLED
numa_free(s.first, s.second);
#else
free(s.first);
#endif
}
}
}
Status NumaMemoryPool::Allocate(size_t n, void **p) {
RETURN_UNEXPECTED_IF_NULL(p);
auto mt = GetRandomDevice();
Status rc;
void *ptr = nullptr;
auto num_segments = memory_segments_.size();
CHECK_FAIL_RETURN_UNEXPECTED(num_segments > 0, "No numa nodes available");
if (NumaAware()) {
auto num_numa_nodes = hw_->GetNumaNodeCount();
// We will start from the numa node this worker id is running on and do a round robin search.
numa_id_t start = hw_->GetMyNode();
numa_id_t node_id = start;
do {
auto it = numa_map_.find(node_id);
if (it != numa_map_.end()) {
auto &slots = it->second;
auto num_slots = slots.size();
std::uniform_int_distribution<int32_t> distribution(0, num_slots - 1);
auto start_slot = distribution(mt);
int32_t inx = start_slot;
do {
int32_t k = slots.at(inx);
std::unique_lock lock_x(mux_[k]);
auto &impl = arena_list_.at(k);
rc = impl->Allocate(n, &ptr);
if (rc.IsOk()) {
*p = ptr;
break;
} else if (rc.IsOutofMemory()) {
inx = (inx + 1) % num_slots;
} else {
return rc;
}
} while (inx != start_slot);
}
// We have done searching for this numa node. If not found, move to the next node.
if (ptr == nullptr) {
node_id = (node_id + 1) % num_numa_nodes;
} else {
break;
}
} while (node_id != start);
} else {
// If not numa aware, just randomly pick a slot.
std::uniform_int_distribution<int32_t> distribution(0, num_segments - 1);
auto start_slot = distribution(mt);
int32_t slot = start_slot;
do {
std::unique_lock lock_x(mux_[slot]);
auto &impl = arena_list_.at(slot);
rc = impl->Allocate(n, &ptr);
if (rc.IsOk()) {
*p = ptr;
break;
} else if (rc.IsOutofMemory()) {
// Make the next arena and continue.
slot = (slot + 1) % num_segments;
} else {
return rc;
}
} while (slot != start_slot);
}
// Handle the case we have done one round robin search.
if (ptr == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
return rc;
}
void NumaMemoryPool::Deallocate(void *p) {
// Find out which numa slot it comes from.
auto slot = Locate(p);
MS_ASSERT(slot != -1);
std::unique_lock lock_x(mux_[slot]);
auto &impl = arena_list_.at(slot);
impl->Deallocate(p);
}
int NumaMemoryPool::PercentFree() const {
int percent_free = 0;
int num_arena = 0;
for (auto const &p : arena_list_) {
percent_free += p->PercentFree();
num_arena++;
}
if (num_arena) {
return percent_free / num_arena;
} else {
return 100;
}
}
int32_t NumaMemoryPool::Locate(void *p) const {
int32_t slot = 0;
char *mem = reinterpret_cast<char *>(p);
for (slot = 0; slot < memory_segments_.size(); ++slot) {
auto elem = memory_segments_.at(slot);
char *q = reinterpret_cast<char *>(elem.first);
if (mem >= q && mem < q + elem.second) {
return slot;
}
}
return -1;
}
std::vector<numa_id_t> NumaMemoryPool::GetAvailableNodes() const {
std::vector<numa_id_t> v;
std::transform(numa_map_.begin(), numa_map_.end(), std::back_inserter(v),
[](const std::pair<numa_id_t, std::vector<int32_t>> &v) { return v.first; });
return v;
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,195 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/memory_pool.h"
namespace mindspore {
namespace dataset {
/// \brief An allocator but for a particular numa node.
template <typename T>
class NumaAllocator {
public:
explicit NumaAllocator(numa_id_t node_id, CachePoolPolicy policy)
: policy_(policy), numa_enabled_(false), node_id_(node_id) {
#ifdef NUMA_ENABLED
numa_enabled_ = numa_available() != -1;
#endif
}
~NumaAllocator() = default;
template <typename U>
explicit NumaAllocator(NumaAllocator<U> const &rhs)
: policy_(rhs.policy_), numa_enabled_(rhs.numa_enabled_), node_id_(rhs.node_id_) {}
template <typename U>
bool operator==(Allocator<U> const &rhs) const {
return node_id_ == rhs.node_id_;
}
template <typename U>
bool operator!=(Allocator<U> const &rhs) const {
return node_id_ != rhs.node_id_;
}
template <typename U>
friend class NumaAllocator;
using value_type = T;
using pointer = T *;
using const_pointer = const T *;
using reference = T &;
using const_reference = const T &;
using size_type = uint64_t;
using difference_type = std::ptrdiff_t;
template <typename U>
struct rebind {
using other = Allocator<U>;
};
using propagate_on_container_copy_assignment = std::true_type;
using propagate_on_container_move_assignment = std::true_type;
using propagate_on_container_swap = std::true_type;
/// Allocate memory on this node only. Return nullptr if no memory on this numa node.
/// \note. This version will not throw if we can't allocate memory from this node.
/// User must check if the pointer returned is null or not.
pointer allocate(std::size_t n) noexcept {
auto sz = n * sizeof(T);
void *p = nullptr;
#ifdef NUMA_ENABLED
if (numa_enabled_) {
switch (policy_) {
case kPreferred:
numa_set_preferred(node_id_);
p = numa_alloc(sz);
break;
case kLocal:
p = numa_alloc_local(sz);
break;
case kInterleave:
p = numa_alloc_interleaved(sz);
break;
case kOnNode:
p = numa_alloc_onnode(sz, node_id_);
break;
case kNone:
default:
p = numa_alloc(sz);
break;
}
} else {
p = malloc(sz);
}
#else
p = malloc(sz);
#endif
return reinterpret_cast<pointer>(p);
}
/// Free a memory allocated on this node.
void deallocate(pointer p, std::size_t n) noexcept {
#ifdef NUMA_ENABLED
if (numa_enabled_) {
numa_free(p, n * sizeof(T));
} else {
free(p);
}
#else
free(p);
#endif
}
/// \brief Allow one to change to another numa node
void SetNodeId(numa_id_t node_id) { node_id_ = node_id; }
/// \brif Getter for node_id;
numa_id_t GetNodeId() const { return node_id_; }
/// \brief Getter for policy
CachePoolPolicy GetPolicy() const { return policy_; }
private:
CachePoolPolicy policy_;
bool numa_enabled_;
numa_id_t node_id_;
};
/// \brief A NumaMemoryPool is like a CircularPool but all the arenas have already been allocated
/// and each one comes from a numa socket. Memory is allocated using OnNode policy. That is,
/// it is solely comes from one particular numa node, and is not interleaved.
class NumaMemoryPool : public MemoryPool {
public:
explicit NumaMemoryPool(std::shared_ptr<CacheServerHW> hw, float memory_cap_ratio);
~NumaMemoryPool() override;
// As a derived class, we override the following functions
Status Allocate(size_t size, void **pVoid) override;
void Deallocate(void *pVoid) override;
Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override { RETURN_STATUS_UNEXPECTED("Not supported"); }
uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); }
int PercentFree() const override;
/// \brief Return if the memory pool is numa aware
bool NumaAware() const { return CacheServerHW::numa_enabled(); }
/// \brief. This returns all the numa nodes that we are able to allocate memory from.
std::vector<numa_id_t> GetAvailableNodes() const;
/// \brief. Given a pointer (allocated from this pool), return the numa node where it is located.
/// \note. -1 is returned if not found.
numa_id_t FindNode(void *p) const {
auto slot = Locate(p);
if (slot != -1) {
return nodes_.at(slot);
} else {
return -1;
}
}
/// \brief Return maximum available memory
int64_t GetAvailableMemory() const { return memory_cap_; }
private:
std::shared_ptr<CacheServerHW> hw_;
float memory_cap_ratio_;
int64_t memory_cap_;
std::vector<std::pair<void *, int64_t>> memory_segments_;
std::vector<std::unique_ptr<ArenaImpl>> arena_list_;
std::unique_ptr<std::mutex[]> mux_;
std::vector<numa_id_t> nodes_;
std::map<numa_id_t, std::vector<int32_t>> numa_map_;
/// \brief. Returns the slot that a given memory comes from.
/// \return slot from numa_segments. -1 if not found.
int32_t Locate(void *p) const;
/// If numa library is not linked, or numa_availble() return -1, we will fall back to this method.
int32_t CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count);
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_

@ -15,18 +15,14 @@
*/
#include <algorithm>
#include "utils/ms_utils.h"
#include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/engine/cache/cache_pool.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/services.h"
namespace mindspore {
namespace dataset {
CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root)
: alloc_(alloc),
root_(root),
subfolder_(Services::GetUniqueID()),
sm_(nullptr),
tree_(nullptr),
custom_arena_(ourOwnArena) {}
CachePool::CachePool(std::shared_ptr<NumaMemoryPool> mp, const std::string &root)
: mp_(std::move(mp)), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {}
Status CachePool::DoServiceStart() {
tree_ = std::make_shared<data_index>();
@ -36,10 +32,11 @@ Status CachePool::DoServiceStart() {
RETURN_IF_NOT_OK(spill.CreateDirectories());
sm_ = std::make_shared<StorageManager>(spill);
RETURN_IF_NOT_OK(sm_->ServiceStart());
MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString());
MS_LOG(INFO) << "CachePool will use disk folder: " << spill.toString();
}
return Status::OK();
}
Status CachePool::DoServiceStop() {
Status rc;
Status rc2;
@ -50,14 +47,14 @@ Status CachePool::DoServiceStop() {
}
}
sm_.reset();
// If it is our own arena, skip freeing individual pieces.
if (!custom_arena_) {
for (auto &bl : *tree_) {
if (bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, bl.sz);
}
value_allocator alloc(mp_);
for (auto &bl : *tree_) {
if (bl.ptr != nullptr) {
alloc.deallocate(bl.ptr, bl.sz);
}
}
tree_.reset();
if (!root_.toString().empty()) {
Path spill = GetSpillPath();
@ -75,8 +72,10 @@ Status CachePool::DoServiceStop() {
}
return rc2;
}
CachePool::~CachePool() noexcept { (void)ServiceStop(); }
Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly) {
Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf) {
DataLocator bl;
Status rc;
size_t sz = 0;
@ -85,26 +84,35 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
sz += v.GetSize();
}
bl.sz = sz;
try {
if (!writeToDiskDirectly) {
bl.ptr = alloc_.allocate(sz);
// We will do a piecewise copy.
WritableSlice dest(bl.ptr, bl.sz);
size_t pos = 0;
for (auto &v : buf) {
WritableSlice out(dest, pos);
rc = WritableSlice::Copy(&out, v);
if (rc.IsError()) {
break;
}
pos += v.GetSize();
}
rc = mp_->Allocate(sz, reinterpret_cast<void **>(&bl.ptr));
if (rc.IsOk()) {
// Write down which numa node where we allocate from. It only make sense if the policy is kOnNode.
if (CacheServerHW::numa_enabled()) {
auto &cs = CacheServer::GetInstance();
auto node_id = cs.GetHWControl()->GetMyNode();
bl.node_id = mp_->FindNode(bl.ptr);
CHECK_FAIL_RETURN_UNEXPECTED(bl.node_id != -1, "Allocator is not from numa memory pool");
bl.node_hit = (bl.node_id == node_id);
}
// We will do a piecewise copy.
WritableSlice dest(bl.ptr, bl.sz);
size_t pos = 0;
for (auto &v : buf) {
WritableSlice out(dest, pos);
rc = WritableSlice::Copy(&out, v);
if (rc.IsError()) {
alloc_.deallocate(bl.ptr, sz);
bl.ptr = nullptr;
return rc;
break;
}
} else if (sm_ != nullptr) {
pos += v.GetSize();
}
if (rc.IsError()) {
mp_->Deallocate(bl.ptr);
bl.ptr = nullptr;
return rc;
}
} else if (rc.IsOutofMemory()) {
// If no memory, write to disk.
if (sm_ != nullptr) {
MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes.";
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
} else {
@ -112,12 +120,8 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
// instead.
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
} catch (std::bad_alloc &e) {
if (sm_ != nullptr) {
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
} else {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
} else {
return rc;
}
// Insert into the B+ tree. We may still get out of memory error. So need to catch it.
try {
@ -127,10 +131,13 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
}
// Duplicate key is treated as error and we will also free the memory.
if (rc.IsError() && bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, sz);
mp_->Deallocate(bl.ptr);
bl.ptr = nullptr;
return rc;
}
return rc;
}
Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const {
RETURN_UNEXPECTED_IF_NULL(dest);
auto r = tree_->Search(key);
@ -156,13 +163,14 @@ Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *byt
}
return Status::OK();
}
const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; }
Path CachePool::GetSpillPath() const {
auto spill = Path(root_) / subfolder_;
return spill;
}
CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
CacheStat cs{-1, -1, 0, 0, 0};
CacheStat cs{-1, -1, 0, 0, 0, 0};
int64_t total_sz = 0;
if (tree_->begin() != tree_->end()) {
cs.min_key = tree_->begin().key();
@ -174,6 +182,9 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
} else {
++cs.num_disk_cached;
}
if (it.value().node_hit) {
++cs.num_numa_hit;
}
auto cur_key = it.key();
if (GetMissingKeys) {
for (auto i = cs.max_key + 1; i < cur_key; ++i) {
@ -192,49 +203,26 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
}
return cs;
}
Status CachePool::Spill(CachePool::DataLocator *dl) {
if (sm_ == nullptr) {
RETURN_STATUS_UNEXPECTED("No disk storage to spill");
}
RETURN_UNEXPECTED_IF_NULL(dl);
RETURN_UNEXPECTED_IF_NULL(dl->ptr);
if (dl->storage_key == 0) {
ReadableSlice data(dl->ptr, dl->sz);
RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data}));
}
alloc_.deallocate(dl->ptr, dl->sz);
dl->ptr = nullptr;
return Status::OK();
}
Status CachePool::Locate(CachePool::DataLocator *dl) {
RETURN_UNEXPECTED_IF_NULL(dl);
if (dl->ptr == nullptr) {
if (sm_ == nullptr) {
RETURN_STATUS_UNEXPECTED("No disk storage to locate the data");
}
try {
dl->ptr = alloc_.allocate(dl->sz);
WritableSlice dest(dl->ptr, dl->sz);
Status rc = Read(dl->storage_key, &dest);
if (rc.IsError()) {
alloc_.deallocate(dl->ptr, dl->sz);
dl->ptr = nullptr;
return rc;
}
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
return Status::OK();
}
size_t CachePool::GetSize(CachePool::key_type key) const {
Status CachePool::GetDataLocator(key_type key, const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
flatbuffers::Offset<DataLocatorMsg> *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
auto r = tree_->Search(key);
if (r.second) {
auto &it = r.first;
return it->sz;
DataLocatorMsgBuilder bld(*fbb);
bld.add_key(key);
bld.add_size(it->sz);
bld.add_node_id(it->node_id);
bld.add_addr(reinterpret_cast<int64_t>(it->ptr));
auto offset = bld.Finish();
*out = offset;
} else {
return 0;
// Key not in the cache.
auto offset = CreateDataLocatorMsg(*fbb, key, 0, 0, 0);
*out = offset;
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -19,11 +19,14 @@
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/engine/cache/storage_manager.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/slice.h"
#include "minddata/dataset/util/storage_manager.h"
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/util/btree.h"
@ -45,13 +48,15 @@ class CachePool : public Service {
// An internal class to locate the whereabouts of a backed up buffer which can be either in
class DataLocator {
public:
DataLocator() : ptr(nullptr), sz(0), storage_key(0) {}
DataLocator() : ptr(nullptr), sz(0), node_id(0), node_hit(false), storage_key(0) {}
~DataLocator() = default;
DataLocator(const DataLocator &other) = default;
DataLocator &operator=(const DataLocator &other) = default;
DataLocator(DataLocator &&other) noexcept {
ptr = other.ptr;
sz = other.sz;
node_id = other.node_id;
node_hit = other.node_hit;
storage_key = other.storage_key;
other.ptr = nullptr;
other.sz = 0;
@ -61,6 +66,8 @@ class CachePool : public Service {
if (&other != this) {
ptr = other.ptr;
sz = other.sz;
node_id = other.node_id;
node_hit = other.node_hit;
storage_key = other.storage_key;
other.ptr = nullptr;
other.sz = 0;
@ -70,6 +77,8 @@ class CachePool : public Service {
}
pointer ptr;
size_t sz;
numa_id_t node_id; // where the numa node the memory is allocated to
bool node_hit; // we can allocate to the preferred node
StorageManager::key_type storage_key;
};
@ -85,19 +94,20 @@ class CachePool : public Service {
int64_t num_mem_cached;
int64_t num_disk_cached;
int64_t average_cache_sz;
int64_t num_numa_hit;
std::vector<key_type> gap;
};
/// \brief Constructor
/// \param alloc Allocator to allocate memory from
/// \param root Optional disk folder to spill
explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = "");
explicit CachePool(std::shared_ptr<NumaMemoryPool> mp, const std::string &root = "");
CachePool(const CachePool &) = delete;
CachePool(CachePool &&) = delete;
CachePool &operator=(const CachePool &) = delete;
CachePool &operator=(CachePool &&) = delete;
~CachePool() noexcept;
~CachePool() noexcept override;
Status DoServiceStart() override;
Status DoServiceStop() override;
@ -110,7 +120,8 @@ class CachePool : public Service {
/// \param[in] buf A sequence of ReadableSlice objects.
/// \param[in] writeToDiskDirectly If true, no spill to disk if spill is enabled, or return no memory
/// \return Error code
Status Insert(key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly);
Status Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf);
/// \brief Restore a cached buffer (from memory or disk)
/// \param[in] key A previous key returned from Insert
/// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice
@ -118,18 +129,14 @@ class CachePool : public Service {
/// \return Error code
Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const;
Status Spill(DataLocator *dl);
Status Locate(DataLocator *dl);
size_t GetSize(key_type key) const;
/// \brief Serialize a DataLocator
Status GetDataLocator(key_type, const std::shared_ptr<flatbuffers::FlatBufferBuilder> &,
flatbuffers::Offset<DataLocatorMsg> *) const;
/// \brief Get statistics.
/// \return CacheStat object
CacheStat GetStat(bool GetMissingKeys = false) const;
const value_allocator &get_allocator() const;
std::string MyName() const { return subfolder_; }
/// \brief Toggle locking
@ -137,12 +144,11 @@ class CachePool : public Service {
void SetLocking(bool on_off) { tree_->SetLocking(on_off); }
private:
value_allocator alloc_;
std::shared_ptr<NumaMemoryPool> mp_;
Path root_;
const std::string subfolder_;
std::shared_ptr<StorageManager> sm_;
std::shared_ptr<data_index> tree_;
bool custom_arena_;
};
} // namespace dataset
} // namespace mindspore

@ -14,6 +14,11 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_request.h"
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
#include <sched.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <cstdlib>
#include <thread>
#include "minddata/dataset/core/constants.h"
@ -106,6 +111,7 @@ Status CacheRowRequest::PostReply() {
}
return Status::OK();
}
Status CacheRowRequest::Prepare() {
if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
// First one is cookie, followed by address and then size.
@ -118,10 +124,21 @@ Status CacheRowRequest::Prepare() {
return Status::OK();
}
BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id,
bool local_bypass)
: BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) {
rq_.set_connection_id(connection_id);
CacheRowRequest::CacheRowRequest(const CacheClient *cc)
: BaseRequest(RequestType::kCacheRow),
support_local_bypass_(cc->local_bypass_),
addr_(-1),
sz_(0),
row_id_from_server_(-1) {
rq_.set_connection_id(cc->server_connection_id_);
rq_.set_client_id(cc->client_id_);
rq_.add_buf_data(cc->cookie_);
}
BatchFetchRequest::BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id)
: BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(cc->local_bypass_), row_id_(row_id) {
rq_.set_connection_id(cc->server_connection_id_);
rq_.set_client_id(cc->client_id_);
rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0);
// Convert the row id into a flatbuffer
flatbuffers::FlatBufferBuilder fbb;
@ -186,9 +203,9 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, in
return Status::OK();
}
CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheRequest::CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheRequest::CreateCacheFlag flag)
: BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) {
: BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag), cc_(cc) {
// Type has been set already in the base constructor. So we need to fill in the connection info.
// On successful return, we will get the connection id
rq_.mutable_connection_info()->operator=(cinfo);
@ -209,6 +226,41 @@ Status CreateCacheRequest::Prepare() {
}
}
Status CreateCacheRequest::PostReply() {
auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
cc_->server_connection_id_ = p->connection_id();
cc_->cookie_ = p->cookie()->str();
cc_->client_id_ = p->client_id();
// Next is a set of cpu id that we should re-adjust ourselves for better affinity.
auto sz = p->cpu_id()->size();
cc_->cpu_list_.reserve(sz);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
std::string c_list;
cpu_set_t cpu_set;
CPU_ZERO(&cpu_set);
#endif
for (auto i = 0; i < sz; ++i) {
auto cpu_id = p->cpu_id()->Get(i);
cc_->cpu_list_.push_back(cpu_id);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
c_list += std::to_string(cpu_id) + " ";
CPU_SET(cpu_id, &cpu_set);
#endif
}
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
if (sz > 0) {
auto err = sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set);
if (err == -1) {
RETURN_STATUS_UNEXPECTED("Unable to set affinity. Errno = " + std::to_string(errno));
}
MS_LOG(WARNING) << "Changing cpu affinity to the following list of cpu id: " + c_list;
}
#endif
return Status::OK();
}
Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
try {
flatbuffers::FlatBufferBuilder fbb;
@ -245,6 +297,7 @@ Status GetStatRequest::PostReply() {
stat_.num_disk_cached = msg->num_disk_cached();
stat_.num_mem_cached = msg->num_mem_cached();
stat_.avg_cache_sz = msg->avg_cache_sz();
stat_.num_numa_hit = msg->num_numa_hit();
stat_.max_row_id = msg->max_row_id();
stat_.min_row_id = msg->min_row_id();
stat_.cache_service_state = msg->state();
@ -255,14 +308,15 @@ Status ListSessionsRequest::PostReply() {
auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
auto session_vector = msg->sessions();
for (auto i = 0; i < session_vector->size(); ++i) {
SessionCacheInfo current_info;
CacheServiceStat stats;
SessionCacheInfo current_info{};
CacheServiceStat stats{};
auto current_session_info = session_vector->Get(i);
current_info.session_id = current_session_info->session_id();
current_info.connection_id = current_session_info->connection_id();
stats.num_mem_cached = current_session_info->stats()->num_mem_cached();
stats.num_disk_cached = current_session_info->stats()->num_disk_cached();
stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz();
stats.num_numa_hit = current_session_info->stats()->num_numa_hit();
stats.min_row_id = current_session_info->stats()->min_row_id();
stats.max_row_id = current_session_info->stats()->max_row_id();
stats.cache_service_state = current_session_info->stats()->state();

@ -41,6 +41,7 @@ struct CacheServiceStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
int64_t avg_cache_sz;
int64_t num_numa_hit;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
@ -75,6 +76,8 @@ class BaseRequest {
kHeartBeat = 14,
kToggleWriteMode = 15,
kListSessions = 16,
kConnectReset = 17,
kInternalFetchRow = 18,
// Add new request before it.
kRequestUnknown = 32767
};
@ -84,10 +87,15 @@ class BaseRequest {
friend class CacheClientGreeter;
friend class CacheClientRequestTag;
friend class CacheClient;
friend class CacheService;
/// \brief Base class of a cache server request
/// \param type Type of the request
explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast<int16_t>(type_)); }
explicit BaseRequest(RequestType type) : type_(type) {
rq_.set_type(static_cast<int16_t>(type_));
rq_.set_client_id(-1);
rq_.set_flag(0);
}
virtual ~BaseRequest() = default;
/// \brief A print method for debugging
@ -138,15 +146,7 @@ class CacheRowRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheClient;
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass)
: BaseRequest(RequestType::kCacheRow),
support_local_bypass_(local_bypass),
addr_(-1),
sz_(0),
row_id_from_server_(-1) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie);
}
explicit CacheRowRequest(const CacheClient *cc);
~CacheRowRequest() override = default;
/// \brief Serialize a TensorRow for streaming to the cache server
@ -193,7 +193,7 @@ class BatchFetchRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass);
BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id);
~BatchFetchRequest() override = default;
Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr);
@ -212,21 +212,18 @@ class CreateCacheRequest : public BaseRequest {
/// \param connection_id
/// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
/// \param flag Attributes of the cache.
explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
explicit CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone);
~CreateCacheRequest() override = default;
void ParseResult(connection_id_type *id, std::string *out) {
auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
*id = p->connection_id();
*out = p->cookie()->str();
}
/// Overload the base class Prepare
/// Overload the base class Prepare/PostReply
Status Prepare() override;
Status PostReply() override;
private:
uint64_t cache_mem_sz_;
CreateCacheFlag flag_;
CacheClient *cc_;
};
/// \brief Request to get all the keys not present at the server.
@ -396,6 +393,23 @@ class ToggleWriteModeRequest : public BaseRequest {
}
~ToggleWriteModeRequest() override = default;
};
class ConnectResetRequest : public BaseRequest {
public:
friend class CacheServer;
explicit ConnectResetRequest(connection_id_type connection_id, int32_t client_id)
: BaseRequest(RequestType::kConnectReset) {
rq_.set_connection_id(connection_id);
rq_.set_client_id(client_id);
}
~ConnectResetRequest() override = default;
/// Override the base class function
Status Prepare() override {
CHECK_FAIL_RETURN_UNEXPECTED(rq_.client_id() != -1, "Invalid client id");
return Status::OK();
}
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_

File diff suppressed because it is too large Load Diff

@ -17,23 +17,31 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <algorithm>
#include <atomic>
#include <chrono>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include <set>
#include <thread>
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include "minddata/dataset/engine/cache/cache_pool.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/semaphore.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/system_pool.h"
@ -47,9 +55,10 @@ class CacheServer : public Service {
public:
friend class Services;
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
class Builder {
public:
Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {}
Builder();
~Builder() = default;
@ -161,26 +170,40 @@ class CacheServer : public Service {
/// \return Status object
static Status ReturnRequestTag(CacheServerRequest *p);
/// \brief This returns the size (in bytes) of the physical RAM on the machine.
/// \return the size (in bytes) of the physical RAM on the machine.
static int64_t GetTotalSystemMemory();
/// Return an instance of the numa control
std::shared_ptr<CacheServerHW> GetHWControl() { return hw_info_; }
/// \brief Internally this is how much we will try to use without exceeding the limit
/// \return Internal cap maximum
int64_t GetAvailableSystemMemory() { return memory_cap_; }
/// \brief Set CPU affinity
Status SetAffinity(const Task &tk, numa_id_t numa_node) { return hw_info_->SetAffinity(tk, numa_node); }
/// \brief Find out the current memory usage
int64_t GetMemoryUsage() { return cur_mem_usage_; }
/// \brief return number of workers
auto GetNumWorkers() const { return num_workers_; }
/// \brief This updates our current memory usage.
enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 };
void UpdateMemoryUsage(int64_t sz, MemUsageOp op) {
if (op == MemUsageOp::kAllocate) {
cur_mem_usage_ += sz;
} else {
cur_mem_usage_ -= sz;
}
}
/// \brief return number of grpc workers
auto GetNumGrpcWorkers() const { return num_grpc_workers_; }
/// \brief return number of numa nodes
auto GetNumaNodeCount() const { return hw_info_->GetNumaNodeCount(); }
/// \brief Assign a worker by a numa id
/// \return worker id
worker_id_t GetWorkerByNumaId(numa_id_t node_id);
/// \brief Randomly pick a worker
/// \return worker id
worker_id_t GetRandomWorker();
/// \brief Check if we bind threads to numa cores
bool IsNumaAffinityOn() const { return numa_affinity_; }
/// \brief Internal function to do row batch fetch
/// \param rq Request
/// \param reply Reply
/// \return Status object
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
/// \brief Return the memory cap ratio
float GetMemoryCapRatio() const { return memory_cap_ratio_; }
private:
static std::once_flag init_instance_flag_;
@ -189,20 +212,21 @@ class CacheServer : public Service {
mutable RWLock sessions_lock_;
std::string top_;
cache_index all_caches_;
std::map<session_id_type, std::set<connection_id_type>> active_sessions_;
std::set<session_id_type> active_sessions_;
std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_;
std::vector<std::unique_ptr<MemGuard<CacheServerRequest, NumaAllocator<CacheServerRequest>>>> tag_;
std::shared_ptr<CacheServerGreeterImpl> comm_layer_;
std::shared_ptr<MemoryPool> mp_;
TaskGroup vg_;
int32_t num_workers_;
int32_t num_grpc_workers_;
int32_t port_;
int32_t shared_memory_sz_in_gb_;
std::atomic<bool> global_shutdown_;
float memory_cap_ratio_;
int64_t memory_cap_;
std::atomic<int64_t> cur_mem_usage_;
std::shared_ptr<CacheServerHW> hw_info_;
std::map<worker_id_t, Task *> numa_tasks_;
bool numa_affinity_;
/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
@ -226,11 +250,11 @@ class CacheServer : public Service {
Status DestroyCache(CacheRequest *rq);
/// \brief Entry point for all internal server threads.
Status ServerRequest(int32_t worker_id);
Status ServerRequest(worker_id_t worker_id);
/// \brief Entry point for all grpc threads.
/// \return
Status RpcRequest(int32_t worker_id);
Status RpcRequest(worker_id_t worker_id);
Status DestroySession(CacheRequest *rq);
@ -266,12 +290,6 @@ class CacheServer : public Service {
Status FastCacheRow(CacheRequest *rq, CacheReply *reply);
Status CacheRow(CacheRequest *rq, CacheReply *reply);
/// \brief Internal function to do row batch fetch
/// \param rq Request
/// \param reply Reply
/// \return Status object
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
/// \brief Internal function to get statistics
/// \param rq
/// \param reply
@ -309,6 +327,9 @@ class CacheServer : public Service {
/// \param reply
/// \return Status object
Status ListSessions(CacheReply *reply);
/// \brief Connect request by a pipeline
Status ConnectReset(CacheRequest *rq);
};
} // namespace dataset
} // namespace mindspore

File diff suppressed because it is too large Load Diff

@ -29,36 +29,28 @@
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_pool.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/btree.h"
#include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/system_pool.h"
namespace mindspore {
namespace dataset {
/// Some typedef used for BatchFetch
using key_size_pair = std::pair<CachePool::key_type, size_t>;
/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is
/// created to support spilling
class CacheService : public Service {
public:
friend class CacheServer;
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
/// \brief Constructor
/// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
/// \param root Spill path. Empty string means no spilling
/// \param generate_id If the cache service should generate row id for buffer that is cached.
/// For non-mappable dataset, this should be set to true.
CacheService(uint64_t mem_sz, const std::string &root, bool generate_id);
~CacheService();
/// \brief For fixed size memory, we will create an Arena.
/// \return false if unlimited memory.
bool UseArena();
~CacheService() override;
Status DoServiceStart() override;
Status DoServiceStop() override;
@ -77,18 +69,18 @@ class CacheService : public Service {
Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated);
/// \brief This function is used in preparation for batch fetching.
/// It calculates how much memory we should allocate and which row id are present.
/// \param[in/out] Pointer to vector of <CachePool::key_type, size_t>
/// \param[in/out] mem_sz how much memory is required to batch fetch
/// It calculates how much memory we should allocate and which row id are present, etc.
/// All needed results are stored in the flat buffer.
/// \return Status object
Status PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *, int64_t *mem_sz);
Status PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
const std::shared_ptr<flatbuffers::FlatBufferBuilder> &);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row id.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &, WritableSlice *out) const;
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &, WritableSlice *out) const;
/// \brief Getter function
/// \return Spilling path
@ -96,7 +88,7 @@ class CacheService : public Service {
/// \brief A structure returned from the cache server for statistics request.
class ServiceStat {
public:
using state_type = std::underlying_type<State>::type;
using state_type = std::underlying_type<CacheServiceState>::type;
ServiceStat() : state_(0) {}
~ServiceStat() = default;
CachePool::CacheStat stat_{};
@ -134,10 +126,6 @@ class CacheService : public Service {
/// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
/// \return Status object
Status BuildPhaseDone();
/// \brief Find out the current memory usage
int64_t GetMemoryUsage() { return cur_mem_usage_; }
/// \brief Find out the current disk usage
int64_t GetDiskUsage() { return cur_disk_usage_; }
/// \brief For kToggleWriteMode request
Status ToggleWriteMode(bool on_off);
@ -149,14 +137,10 @@ class CacheService : public Service {
std::atomic<row_id_type> next_id_;
bool generate_id_;
std::string cookie_;
State st_;
std::atomic<int32_t> num_clients_;
CacheServiceState st_;
std::string schema_;
// If we use an Arena, cur_disk_usage is always 0 as we don't know how CachePool manages it.
// Otherwise we track how much is in memory and how much is on disk (if root_ is not empty).
// We use them to control when we should stop caching in memory in the case when there is no
// Arena.
std::atomic<int64_t> cur_mem_usage_;
std::atomic<int64_t> cur_disk_usage_;
std::shared_ptr<NumaMemoryPool> numa_pool_;
// We also cache the result from calling FindKeysMiss because it is expensive. Besides user make
// this request after we hit memory full or disk full. So the result is unlikely to change.
std::mutex get_key_miss_mux_;
@ -164,6 +148,8 @@ class CacheService : public Service {
/// \brief Private function to generate a row id
/// \return Row id assigned.
row_id_type GetNextRowId() { return next_id_.fetch_add(1); }
Status InternalFetchRow(const FetchRowMsg *p);
};
} // namespace dataset
} // namespace mindspore

@ -65,6 +65,7 @@ table ServiceStatMsg {
num_mem_cached:int64;
num_disk_cached:int64;
avg_cache_sz:int64;
num_numa_hit:int64;
min_row_id:int64;
max_row_id:int64;
state:int8;
@ -89,8 +90,10 @@ table CreateCacheRequestMsg {
/// Return result of CreateCacheRequest
table CreateCacheReplyMsg {
connection_id:int64;
client_id:int32;
connection_id:uint64;
cookie:string;
cpu_id:[int32];
}
table ListSessionMsg {
@ -102,3 +105,22 @@ table ListSessionMsg {
table ListSessionsMsg {
sessions:[ListSessionMsg];
}
table DataLocatorMsg {
key:int64;
node_id:int32;
addr:int64;
size:int64;
}
table BatchDataLocatorMsg {
connection_id:uint64;
rows:[DataLocatorMsg];
}
table FetchRowMsg {
key:int64;
source_addr:int64;
dest_addr:int64;
size:int64;
}

@ -0,0 +1,32 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_CACHE)
ms_protobuf_generate(CACHE_PERF_PROTO_SRCS CACHE_PERF_PROTO_HDRS cache_perf.proto)
add_executable(cache_perf cache_perf.cc cache_msg.cc cache_perf_run.cc ${CACHE_PERF_PROTO_SRCS})
target_link_libraries(cache_perf
_c_dataengine
_c_mindrecord
mindspore::protobuf
mindspore_gvar
${PYTHON_LIBRARIES}
pthread)
if (USE_GLOG)
target_link_libraries(cache_perf mindspore::glog)
endif ()
add_executable(cache_pipeline cache_pipeline.cc cache_msg.cc cache_pipeline_run.cc ${CACHE_PERF_PROTO_SRCS})
target_link_libraries(cache_pipeline
_c_dataengine
_c_mindrecord
mindspore::protobuf
mindspore_gvar
${PYTHON_LIBRARIES}
pthread)
if (USE_GLOG)
target_link_libraries(cache_pipeline mindspore::glog)
endif ()
endif ()

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/perf/cache_msg.h"
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/msg.h>
namespace mindspore {
namespace dataset {
Status CachePerfMsg::Send(int32_t qID) {
auto err = msgsnd(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), IPC_NOWAIT);
if (err == -1) {
std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status CachePerfMsg::Receive(int32_t qID) {
// This is a blocking call. Either there is some message or we the queue is removed when
// the destructor is called.
auto err = msgrcv(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), 0, MSG_NOERROR);
if (err == -1) {
if (errno == EIDRM) {
return Status(StatusCode::kInterrupted);
} else {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save