pull/8315/head
Lixia Chen 5 years ago committed by Jesse Lee
parent 5792d50ca7
commit f24a788eed

@ -62,7 +62,7 @@ if (ENABLE_CACHE)
endif () endif ()
add_executable(cache_admin cache_admin.cc cache_admin_arg.cc) add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES}) target_link_libraries(cache_admin _c_dataengine _c_mindrecord mindspore::protobuf ${PYTHON_LIBRARIES} pthread)
if (USE_GLOG) if (USE_GLOG)
target_link_libraries(cache_admin mindspore::glog) target_link_libraries(cache_admin mindspore::glog)

@ -27,6 +27,8 @@
#include <vector> #include <vector>
#include "minddata/dataset/engine/cache/cache_request.h" #include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/path.h"
#include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/constants.h"
@ -325,9 +327,33 @@ Status CacheAdminArgHandler::RunCommand() {
Help(); Help();
break; break;
} }
case CommandId::kCmdStart: case CommandId::kCmdStart: {
RETURN_IF_NOT_OK(StartServer(command_id_));
break;
}
case CommandId::kCmdStop: { case CommandId::kCmdStop: {
RETURN_IF_NOT_OK(StartStopServer(command_id_)); CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
SharedMessage msg;
RETURN_IF_NOT_OK(msg.Create());
auto rq = std::make_shared<ServerStopRequest>(msg.GetMsgQueueId());
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
Status rc = rq->Wait();
if (rc.IsError()) {
msg.RemoveResourcesOnExit();
if (rc.IsNetWorkError()) {
std::string errMsg = "Server is not up or has been shutdown already.";
return Status(StatusCode::kNetWorkError, errMsg);
}
return rc;
}
// OK return code only means the server acknowledge our request but we still
// have to wait for its complete shutdown because the server will shutdown
// the comm layer as soon as the request is received, and we need to wait
// on the message queue instead.
// The server will remove the queue and we will then wake up.
Status dummy_rc;
(void)msg.ReceiveStatus(&dummy_rc);
break; break;
} }
case CommandId::kCmdGenerateSession: { case CommandId::kCmdGenerateSession: {
@ -396,7 +422,7 @@ Status CacheAdminArgHandler::RunCommand() {
return Status::OK(); return Status::OK();
} }
Status CacheAdminArgHandler::StartStopServer(CommandId command_id) { Status CacheAdminArgHandler::StartServer(CommandId command_id) {
// There currently does not exist any "install path" or method to identify which path the installed binaries will // There currently does not exist any "install path" or method to identify which path the installed binaries will
// exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the
// cache_admin binary (this process). // cache_admin binary (this process).
@ -477,23 +503,15 @@ Status CacheAdminArgHandler::StartStopServer(CommandId command_id) {
std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_); std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_);
char *argv[9]; char *argv[9];
if (command_id == CommandId::kCmdStart) { argv[0] = cache_server_binary.data();
argv[0] = cache_server_binary.data(); argv[1] = spill_dir_.data();
argv[1] = spill_dir_.data(); argv[2] = workers_string.data();
argv[2] = workers_string.data(); argv[3] = port_string.data();
argv[3] = port_string.data(); argv[4] = shared_memory_string.data();
argv[4] = shared_memory_string.data(); argv[5] = minloglevel_string.data();
argv[5] = minloglevel_string.data(); argv[6] = daemonize_string.data();
argv[6] = daemonize_string.data(); argv[7] = memory_cap_ratio_string.data();
argv[7] = memory_cap_ratio_string.data(); argv[8] = nullptr;
argv[8] = nullptr;
} else {
// We are doing a --stop. Change the name to '-' and we also need the port number.
// The rest we don't need.
argv[0] = std::string("-").data();
argv[1] = port_string.data();
argv[2] = nullptr;
}
// Now exec the binary // Now exec the binary
execv(cache_server_binary.data(), argv); execv(cache_server_binary.data(), argv);
@ -509,17 +527,27 @@ void CacheAdminArgHandler::Help() {
std::cerr << "Syntax:\n"; std::cerr << "Syntax:\n";
std::cerr << " cache_admin [--start | --stop]\n"; std::cerr << " cache_admin [--start | --stop]\n";
std::cerr << " [ [-h | --hostname] <hostname> ]\n"; std::cerr << " [ [-h | --hostname] <hostname> ]\n";
std::cerr << " Default is " << kCfgDefaultCacheHost << ".\n";
std::cerr << " [ [-p | --port] <port number> ]\n"; std::cerr << " [ [-p | --port] <port number> ]\n";
std::cerr << " Possible values are in range [1025..65535].\n";
std::cerr << " Default is " << kCfgDefaultCachePort << ".\n";
std::cerr << " [ [-g | --generate_session] ]\n"; std::cerr << " [ [-g | --generate_session] ]\n";
std::cerr << " [ [-d | --destroy_session] <session id> ]\n"; std::cerr << " [ [-d | --destroy_session] <session id> ]\n";
std::cerr << " [ [-w | --workers] <number of workers> ]\n"; std::cerr << " [ [-w | --workers] <number of workers> ]\n";
std::cerr << " Possible values are in range [1...max(100, Number of CPU)].\n";
std::cerr << " Default is " << kDefaultNumWorkers << ".\n";
std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n"; std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n";
std::cerr << " Default is " << kDefaultSpillDir << ".\n";
std::cerr << " [ [-l | --minloglevel] <log level> ]\n"; std::cerr << " [ [-l | --minloglevel] <log level> ]\n";
std::cerr << " Possible values are 0, 1, 2 and 3.\n";
std::cerr << " Default is 1 (info level).\n";
std::cerr << " [ --list_sessions ]\n"; std::cerr << " [ --list_sessions ]\n";
// Do not expose these option to the user via help or documentation, but the options do exist to aid with // Do not expose these option to the user via help or documentation, but the options do exist to aid with
// development and tuning. // development and tuning.
// std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n"; // std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n";
// std::cerr << " Default is " << kDefaultSharedMemorySizeInGB << " (Gb in unit).\n";
// std::cerr << " [ [-r | --memory_cap_ratio] <float percent value>]\n"; // std::cerr << " [ [-r | --memory_cap_ratio] <float percent value>]\n";
// std::cerr << " Default is " << kMemoryCapRatio << ".\n";
std::cerr << " [--help]" << std::endl; std::cerr << " [--help]" << std::endl;
} }
} // namespace dataset } // namespace dataset

@ -78,7 +78,7 @@ class CacheAdminArgHandler {
kArgNumArgs = 14 // Must be the last position to provide a count kArgNumArgs = 14 // Must be the last position to provide a count
}; };
Status StartStopServer(CommandId); Status StartServer(CommandId command_id);
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown); CommandId command_id = CommandId::kCmdUnknown);

@ -21,11 +21,7 @@ CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
// We create the shared memory and we will destroy it. All other client just detach only. // We create the shared memory and we will destroy it. All other client just detach only.
shm_.RemoveResourcesOnExit(); shm_.RemoveResourcesOnExit();
} }
CachedSharedMemoryArena::~CachedSharedMemoryArena() { CachedSharedMemoryArena::~CachedSharedMemoryArena() {}
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
}
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) { size_t val_in_GB) {

@ -72,6 +72,9 @@ class CachedSharedMemoryArena : public MemoryPool {
return os; return os;
} }
/// \brief Get the key of the shared memory segment owned by this arena.
/// \return The value reported by shm_.GetKey().
SharedMemory::shm_key_t GetKey() const { return shm_.GetKey(); }
private: private:
mutable std::mutex mux_; mutable std::mutex mux_;
int32_t val_in_GB_; int32_t val_in_GB_;

@ -13,10 +13,12 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <chrono>
#include <limits> #include <limits>
#include "minddata/dataset/engine/cache/cache_grpc_server.h" #include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include "minddata/dataset/engine/cache/cache_server.h" #include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/task_manager.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#else #else
@ -25,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb) CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) { : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb), shm_key_(-1) {
// Setup a path for unix socket. // Setup a path for unix socket.
unix_socket_ = PortToUnixSocketPath(port); unix_socket_ = PortToUnixSocketPath(port);
// We can't generate the ftok key yet until the unix_socket_ is created // We can't generate the ftok key yet until the unix_socket_ is created
@ -73,7 +75,8 @@ Status CacheServerGreeterImpl::Run() {
MS_LOG(INFO) << "Server listening on " << server_address; MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT #if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_)); RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
MS_LOG(INFO) << "Creation of local socket and shared memory successful"; shm_key_ = shm_pool_->GetKey();
MS_LOG(INFO) << "Creation of local socket and shared memory successful. Shared memory key " << shm_key_;
auto cs = CacheServer::GetInstance().GetHWControl(); auto cs = CacheServer::GetInstance().GetHWControl();
// This shared memory is a hot memory and we will interleave among all the numa nodes. // This shared memory is a hot memory and we will interleave among all the numa nodes.
cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L); cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
@ -181,5 +184,32 @@ void CacheServerRequest::Print(std::ostream &out) const {
out << " "; out << " ";
BaseRequest::Print(out); BaseRequest::Print(out);
} }
// Background task that polls the unix socket path every 5 seconds.
// When CACHE_LOCAL_CLIENT is not enabled the function only posts to the task and returns OK.
// \return An error Status if the task is interrupted, if PortToFtok fails, or if the key derived
//         from the socket path no longer matches shm_key_.
Status CacheServerGreeterImpl::MonitorUnixSocket() {
  // NOTE(review): presumably tells the creator of this task that it has started - confirm.
  TaskManager::FindMe()->Post();
#if CACHE_LOCAL_CLIENT
  Path socket_path(unix_socket_);
  for (;;) {
    RETURN_IF_INTERRUPTED();
    if (!socket_path.Exists()) {
      // The socket file is gone. Log it and wake up the watch dog; the loop keeps running
      // until the interrupt check above fires.
      MS_LOG(WARNING) << "Unix socket is removed.";
      TaskManager::WakeUpWatchDog();
    } else {
      // If the unix socket is recreated for whatever reason, this server instance will be stale and
      // no other process can communicate with us. In this case we need to shut ourselves down.
      SharedMemory::shm_key_t current_key;
      RETURN_IF_NOT_OK(PortToFtok(port_, &current_key));
      if (current_key != shm_key_) {
        std::string err_msg = "Detecting unix socket has changed. Previous key " + std::to_string(shm_key_) +
                              ". New key " + std::to_string(current_key) + ". Shutting down server";
        MS_LOG(ERROR) << err_msg;
        RETURN_STATUS_UNEXPECTED(err_msg);
      }
    }
    std::this_thread::sleep_for(std::chrono::seconds(5));
  }
#endif
  return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -87,6 +87,9 @@ class CacheServerGreeterImpl final {
/// \return Return the shared memory pool /// \return Return the shared memory pool
CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); } CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }
/// \brief Monitor the status of the unix socket in case it is gone.
Status MonitorUnixSocket();
void Shutdown(); void Shutdown();
private: private:
@ -97,6 +100,7 @@ class CacheServerGreeterImpl final {
std::unique_ptr<grpc::ServerCompletionQueue> cq_; std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_; std::unique_ptr<grpc::Server> server_;
std::unique_ptr<CachedSharedMemoryArena> shm_pool_; std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
SharedMemory::shm_key_t shm_key_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -160,6 +160,9 @@ class SharedMemory : public BaseIPC {
/// \brief Set the public key /// \brief Set the public key
void SetPublicKey(key_t public_key) { shm_key_ = public_key; } void SetPublicKey(key_t public_key) { shm_key_ = public_key; }
/// \brief Retrieve the shared memory key.
/// \return The current value of shm_key_ (the same member SetPublicKey assigns).
shm_key_t GetKey() const { return shm_key_; }
/// \brief This returns where we attach to the shared memory. /// \brief This returns where we attach to the shared memory.
/// \return Base address of the shared memory. /// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shmat_addr_; } const void *SharedMemoryBaseAddr() const { return shmat_addr_; }

@ -27,116 +27,6 @@
#include "minddata/dataset/engine/cache/cache_ipc.h" #include "minddata/dataset/engine/cache/cache_ipc.h"
namespace ds = mindspore::dataset; namespace ds = mindspore::dataset;
/// Send one command to the local server and block until its reply (or a transport error) arrives.
/// This binary is not necessarily linked with the client library, so grpc is called directly.
/// The channel targets the unix socket derived from the port when CACHE_LOCAL_CLIENT is set,
/// otherwise 127.0.0.1:<port>.
/// \param port Port number; used for the tcp/ip target or to derive the unix socket path
/// \param type Type of command; stored into the request together with client id -1 and flag 0
/// \param rq Request message to send; must not be null
/// \param reply Receives the server reply; must not be null
/// \param out Receives the grpc result; must not be null
/// \return Status object built from the reply when the grpc call succeeded, kNetWorkError when grpc
///         reports a failure, kUnexpectedError for null arguments, unexpected queue results or exceptions
ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::CacheRequest *rq, ds::CacheReply *reply,
                           grpc::Status *out) {
  if (rq == nullptr) {
    return ds::Status(ds::StatusCode::kUnexpectedError, "pointer rq is null");
  }
  if (reply == nullptr) {
    return ds::Status(ds::StatusCode::kUnexpectedError, "pointer reply is null");
  }
  if (out == nullptr) {
    return ds::Status(ds::StatusCode::kUnexpectedError, "pointer out is null");
  }
  const std::string hostname = "127.0.0.1";
  auto unix_socket = ds::PortToUnixSocketPath(port);
#if CACHE_LOCAL_CLIENT
  const std::string target = "unix://" + unix_socket;
#else
  const std::string target = hostname + ":" + std::to_string(port);
#endif
  try {
    rq->set_type(static_cast<int16_t>(type));
    rq->set_client_id(-1);
    rq->set_flag(0);
    grpc::ChannelArguments args;
    grpc::ClientContext ctx;
    grpc::CompletionQueue cq;
    // Standard async rpc call
    auto channel = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
    auto stub = ds::CacheServerGreeter::NewStub(channel);
    auto rpc = stub->PrepareAsyncCacheServerRequest(&ctx, *rq, &cq);
    rpc->StartCall();
    // A tag is required, but this is the only request in the completion queue,
    // so the address of a dummy local is enough to identify it.
    int64_t dummy;
    void *tag;
    bool success;
    rpc->Finish(reply, out, &dummy);
    // Now we wait on the completion queue synchronously.
    auto r = cq.Next(&tag, &success);
    if (r != grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
      std::string errMsg = "Unexpected queue rc = " + std::to_string(r);
      return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
    }
    const bool is_our_event = success && tag == &dummy;
    if (!is_our_event) {
      std::string errMsg = "Unexpected programming error ";
      return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
    }
    if (!out->ok()) {
      // Transport level failure: surface the grpc message and code as a network error.
      auto error_code = out->error_code();
      std::string errMsg = out->error_message() + ". GRPC Code " + std::to_string(error_code);
      return ds::Status(ds::StatusCode::kNetWorkError, errMsg);
    }
    return ds::Status(static_cast<ds::StatusCode>(reply->rc()), reply->msg());
  } catch (const std::exception &e) {
    return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what());
  }
}
/// Stop the server
/// \param argc Number of entries in argv; anything other than 2 is a syntax error
/// \param argv argv[1] is parsed as the decimal port number of the server to stop
/// \return Status object
ds::Status StopServer(int argc, char **argv) {
  ds::CacheServer::Builder builder;
  if (argc != 2) {
    return ds::Status(ds::StatusCode::kSyntaxError);
  }
  int32_t port = strtol(argv[1], nullptr, 10);
  // We will go through the builder to do some sanity check. We only need the port number
  // to shut down the server. Null the root directory as we don't trigger the sanity code to write out anything
  // to the spill directory.
  builder.SetPort(port).SetRootDirectory("");
  // Part of the sanity check is check the shared memory. If the server is up and running, we expect
  // the return code is kDuplicate.
  ds::Status sanity_rc = builder.SanityCheck();
  if (sanity_rc.IsOk()) {
    std::string errMsg = "Server is not up or has been shutdown already.";
    return ds::Status(ds::StatusCode::kUnexpectedError, errMsg);
  }
  if (sanity_rc.get_code() != ds::StatusCode::kDuplicateKey) {
    // Not OK, and no duplicate, just return the rc whatever it is.
    return sanity_rc;
  }
  // Now we get some work to do. We will send a tcp/ip request to the given port.
  // This binary is not linked with client side of code, so we will just call grpc directly.
  ds::CacheRequest rq;
  ds::CacheReply reply;
  grpc::Status grpc_rc;
  ds::Status send_rc = SendSyncCommand(port, ds::BaseRequest::RequestType::kStopService, &rq, &reply, &grpc_rc);
  // The request is like a self destruct message, the server will not send anything back and
  // shutdown all incoming request. So we should expect some unexpected network error if
  // all goes well and we expect to GRPC code 14.
  const bool got_network_error = send_rc.get_code() == ds::StatusCode::kNetWorkError;
  const bool grpc_unavailable = grpc_rc.error_code() == grpc::StatusCode::UNAVAILABLE;
  if (!(got_network_error && grpc_unavailable)) {
    return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__);
  }
  return ds::Status::OK();
}
/// Start the server /// Start the server
/// \param argv /// \param argv
/// \return Status object /// \return Status object
@ -235,15 +125,8 @@ ds::Status StartServer(int argc, char **argv) {
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {
ds::Status rc;
ds::CacheServer::Builder builder;
// This executable is not to be called directly, and should be invoked by cache_admin executable. // This executable is not to be called directly, and should be invoked by cache_admin executable.
if (strcmp(argv[0], "-") == 0) { ds::Status rc = StartServer(argc, argv);
rc = StopServer(argc, argv);
} else {
rc = StartServer(argc, argv);
}
// Check result // Check result
if (rc.IsError()) { if (rc.IsError()) {
auto errCode = rc.get_code(); auto errCode = rc.get_code();

@ -20,6 +20,7 @@
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <cstdlib> #include <cstdlib>
#include <cstring>
#include <thread> #include <thread>
#include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/cache/cache_client.h"
@ -326,5 +327,11 @@ Status ListSessionsRequest::PostReply() {
return Status::OK(); return Status::OK();
} }
// Check the acknowledgement carried in the reply to a stop request.
// \return Status::OK() when the reply result, read as a C string, equals "OK"; otherwise the
//         error raised by CHECK_FAIL_RETURN_UNEXPECTED with "Not the right response".
Status ServerStopRequest::PostReply() {
  const char *ack_text = reply_.result().data();
  const bool acknowledged = (strcmp(ack_text, "OK") == 0);
  CHECK_FAIL_RETURN_UNEXPECTED(acknowledged, "Not the right response");
  return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -394,6 +394,15 @@ class ToggleWriteModeRequest : public BaseRequest {
~ToggleWriteModeRequest() override = default; ~ToggleWriteModeRequest() override = default;
}; };
/// \brief Request of type kStopService that carries a message queue id in its first buf_data entry.
class ServerStopRequest : public BaseRequest {
 public:
  friend class CacheServer;

  /// \brief Build the stop request.
  /// \param qID Message queue id; its decimal text form is appended to the request's buf_data.
  explicit ServerStopRequest(int32_t qID) : BaseRequest(RequestType::kStopService) {
    const std::string queue_id_text = std::to_string(qID);
    rq_.add_buf_data(queue_id_text);
  }

  /// \brief Validate the reply received for this request.
  Status PostReply() override;
};
class ConnectResetRequest : public BaseRequest { class ConnectResetRequest : public BaseRequest {
public: public:
friend class CacheServer; friend class CacheServer;

@ -108,6 +108,9 @@ Status CacheServer::DoServiceStart() {
try { try {
comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_); comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
RETURN_IF_NOT_OK(comm_layer_->Run()); RETURN_IF_NOT_OK(comm_layer_->Run());
// Bring up a thread to monitor the unix socket in case it is removed.
auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
} catch (const std::exception &e) { } catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what()); RETURN_STATUS_UNEXPECTED(e.what());
} }
@ -154,6 +157,15 @@ Status CacheServer::DoServiceStop() {
} }
++it; ++it;
} }
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
// Finally wake up cache_admin if it is waiting
for (int32_t qID : shutdown_qIDs_) {
SharedMessage msg(qID);
msg.RemoveResourcesOnExit();
// Let msg goes out of scope which will destroy the queue.
}
return rc; return rc;
} }
@ -374,6 +386,68 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
return rc; return rc;
} }
// Fetch the rows described by a BatchDataLocatorMsg flatbuffer into one contiguous output slice.
// Output layout: an array of (num_elements + 1) int64_t offsets at the start of `out`, followed by the row data.
// offset_array[0] is the size of that offset array; offset_array[i + 1] is offset_array[i] plus the row size
// rounded up by round_up_4K. Row i is written at offset_array[i]. Rows whose locator size is 0 are not fetched.
// One kInternalFetchRow request is pushed to a worker per non-empty row, then all of them are awaited.
// \param[in] fbb Flatbuffer whose root is a BatchDataLocatorMsg (connection id plus one locator per row).
// \param[out] out Destination slice; must not be null.
// \return The first non-interrupted error reported by a fetch request, or the error of a failed call; OK otherwise.
Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
RETURN_UNEXPECTED_IF_NULL(out);
int32_t numQ = GetNumGrpcWorkers();
auto rng = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
// Random starting index in [0, numQ - 1]; it is advanced modulo numQ for every request tag taken below.
int32_t qID = distribution(rng);
std::vector<CacheServerRequest *> cache_rq_list;
auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
const auto num_elements = p->rows()->size();
auto connection_id = p->connection_id();
cache_rq_list.reserve(num_elements);
// The row data starts right after the (num_elements + 1) offsets.
int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
// NOTE(review): assumes `out` is large enough for the offset array plus all 4K-rounded rows - confirm with caller.
auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
offset_array[0] = data_offset;
for (auto i = 0; i < num_elements; ++i) {
auto data_locator = p->rows()->Get(i);
auto node_id = data_locator->node_id();
size_t sz = data_locator->size();
void *source_addr = reinterpret_cast<void *>(data_locator->addr());
auto key = data_locator->key();
// Please read the comment in CacheServer::BatchFetchRows where we allocate
// the buffer big enough so each thread (which we are going to dispatch) will
// not run into false sharing problem. We are going to round up sz to 4k.
auto sz_4k = round_up_4K(sz);
offset_array[i + 1] = offset_array[i] + sz_4k;
if (sz > 0) {
WritableSlice row_data(*out, offset_array[i], sz);
// Get a request and send to the proper worker (at some numa node) to do the fetch.
worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker();
CacheServerRequest *cache_rq;
// NOTE(review): if this (or PushRequest below) returns early, requests already pushed in earlier
// iterations are neither waited on nor given back via ReturnRequestTag by this function - confirm
// that is acceptable, since they still hold pointers into `out`.
RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
cache_rq_list.push_back(cache_rq);
// Set up all the necessary fields.
cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
cache_rq->rq_.set_connection_id(connection_id);
cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
auto dest_addr = row_data.GetMutablePointer();
// Describe the copy (key, size, source and destination addresses) in a FetchRowMsg for the worker.
flatbuffers::FlatBufferBuilder fb2;
FetchRowMsgBuilder bld(fb2);
bld.add_key(key);
bld.add_size(sz);
bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
auto offset = bld.Finish();
fb2.Finish(offset);
cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq));
}
}
// Now wait for all of them to come back.
Status rc;
for (CacheServerRequest *rq : cache_rq_list) {
// NOTE(review): an early return here skips ReturnRequestTag for this and all remaining requests - confirm.
RETURN_IF_NOT_OK(rq->Wait());
// Keep only the first error, and do not record interrupted requests as errors.
if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
rc = rq->rc_;
}
RETURN_IF_NOT_OK(ReturnRequestTag(rq));
}
return rc;
}
Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id(); auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped. // Hold the shared lock to prevent the cache from being dropped.
@ -394,6 +468,9 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
} }
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>(); std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb)); RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb));
// Let go of the shared lock. We don't need to interact with the CacheService anymore.
// We shouldn't be holding any lock while we can wait for a long time for the rows to come back.
lck.Unlock();
auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer()); auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
int64_t mem_sz = sizeof(int64_t) * (sz + 1); int64_t mem_sz = sizeof(int64_t) * (sz + 1);
for (auto i = 0; i < sz; ++i) { for (auto i = 0; i < sz; ++i) {
@ -418,7 +495,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
void *q = nullptr; void *q = nullptr;
RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q)); RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
WritableSlice dest(q, mem_sz); WritableSlice dest(q, mem_sz);
Status rc = cs->BatchFetch(fbb, &dest); Status rc = BatchFetch(fbb, &dest);
if (rc.IsError()) { if (rc.IsError()) {
shared_pool->Deallocate(q); shared_pool->Deallocate(q);
return rc; return rc;
@ -439,7 +516,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
return Status(StatusCode::kOutOfMemory); return Status(StatusCode::kOutOfMemory);
} }
WritableSlice dest(mem.data(), mem_sz); WritableSlice dest(mem.data(), mem_sz);
RETURN_IF_NOT_OK(cs->BatchFetch(fbb, &dest)); RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));
reply->set_result(std::move(mem)); reply->set_result(std::move(mem));
} }
} }
@ -721,7 +798,7 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) {
} }
case BaseRequest::RequestType::kStopService: { case BaseRequest::RequestType::kStopService: {
// This command shutdowns everything. // This command shutdowns everything.
cache_req->rc_ = GlobalShutdown(); cache_req->rc_ = GlobalShutdown(cache_req);
break; break;
} }
case BaseRequest::RequestType::kHeartBeat: { case BaseRequest::RequestType::kHeartBeat: {
@ -914,7 +991,25 @@ Status CacheServer::RpcRequest(worker_id_t worker_id) {
return Status::OK(); return Status::OK();
} }
Status CacheServer::GlobalShutdown() { Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
auto *rq = &cache_req->rq_;
auto *reply = &cache_req->reply_;
if (!rq->buf_data().empty()) {
// cache_admin sends us a message qID and we will destroy the
// queue in our destructor and this will wake up cache_admin.
// But we don't want the cache_admin blindly just block itself.
// So we will send back an ack before shutdown the comm layer.
try {
int32_t qID = std::stoi(rq->buf_data(0));
shutdown_qIDs_.push_back(qID);
} catch (const std::exception &e) {
// ignore it.
}
}
reply->set_result("OK");
Status2CacheReply(cache_req->rc_, reply);
cache_req->st_ = CacheServerRequest::STATE::FINISH;
cache_req->responder_.Finish(*reply, grpc::Status::OK, cache_req);
// Let's shutdown in proper order. // Let's shutdown in proper order.
bool expected = false; bool expected = false;
if (global_shutdown_.compare_exchange_strong(expected, true)) { if (global_shutdown_.compare_exchange_strong(expected, true)) {
@ -939,7 +1034,7 @@ Status CacheServer::GlobalShutdown() {
return Status::OK(); return Status::OK();
} }
worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) { worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {
auto num_numa_nodes = GetNumaNodeCount(); auto num_numa_nodes = GetNumaNodeCount();
MS_ASSERT(numa_id < num_numa_nodes); MS_ASSERT(numa_id < num_numa_nodes);
auto num_workers_per_node = GetNumWorkers() / num_numa_nodes; auto num_workers_per_node = GetNumWorkers() / num_numa_nodes;
@ -951,7 +1046,7 @@ worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) {
return worker_id; return worker_id;
} }
worker_id_t CacheServer::GetRandomWorker() { worker_id_t CacheServer::GetRandomWorker() const {
std::mt19937 gen = GetRandomDevice(); std::mt19937 gen = GetRandomDevice();
std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1); std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1);
return dist(gen); return dist(gen);

@ -187,11 +187,11 @@ class CacheServer : public Service {
/// \brief Assign a worker by a numa id /// \brief Assign a worker by a numa id
/// \return worker id /// \return worker id
worker_id_t GetWorkerByNumaId(numa_id_t node_id); worker_id_t GetWorkerByNumaId(numa_id_t node_id) const;
/// \brief Randomly pick a worker /// \brief Randomly pick a worker
/// \return worker id /// \return worker id
worker_id_t GetRandomWorker(); worker_id_t GetRandomWorker() const;
/// \brief Check if we bind threads to numa cores /// \brief Check if we bind threads to numa cores
bool IsNumaAffinityOn() const { return numa_affinity_; } bool IsNumaAffinityOn() const { return numa_affinity_; }
@ -227,6 +227,7 @@ class CacheServer : public Service {
std::shared_ptr<CacheServerHW> hw_info_; std::shared_ptr<CacheServerHW> hw_info_;
std::map<worker_id_t, Task *> numa_tasks_; std::map<worker_id_t, Task *> numa_tasks_;
bool numa_affinity_; bool numa_affinity_;
std::vector<int32_t> shutdown_qIDs_;
/// \brief Constructor /// \brief Constructor
/// \param spill_path Top directory for spilling buffers to. /// \param spill_path Top directory for spilling buffers to.
@ -314,7 +315,7 @@ class CacheServer : public Service {
/// \brief A proper shutdown of the server /// \brief A proper shutdown of the server
/// \return Status object /// \return Status object
Status GlobalShutdown(); Status GlobalShutdown(CacheServerRequest *);
/// \brief Find keys that will be cache miss /// \brief Find keys that will be cache miss
/// \return Status object /// \return Status object
@ -330,6 +331,13 @@ class CacheServer : public Service {
/// \brief Connect request by a pipeline /// \brief Connect request by a pipeline
Status ConnectReset(CacheRequest *rq); Status ConnectReset(CacheRequest *rq);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] fbb A flatbuffer whose root is a BatchDataLocatorMsg describing the rows to fetch.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out);
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -209,6 +209,10 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v, Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) { const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
SharedLock rw(&rw_lock_); SharedLock rw(&rw_lock_);
if (st_ == CacheServiceState::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v; std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
datalocator_v.reserve(v.size()); datalocator_v.reserve(v.size());
for (auto row_id : v) { for (auto row_id : v) {
@ -225,76 +229,6 @@ Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::
return Status::OK(); return Status::OK();
} }
/// \brief Fetch a batch of rows into one contiguous output buffer. One internal fetch request is
/// dispatched to a cache server worker for every row with a non-zero size, and the call then waits
/// for all dispatched requests to complete.
/// \param[in] fbb Flatbuffer builder whose buffer is read as a BatchDataLocatorMsg (a connection id
///   plus, for each row, its key, size, source address and numa node id).
/// \param[out] out Destination slice. It begins with (number of rows + 1) int64_t offsets, followed
///   by the row data; each row occupies a slot of its size rounded up by round_up_4K.
/// \return Status object. The first worker error that is not an interrupt is returned.
Status CacheService::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
SharedLock rw(&rw_lock_);
if (st_ == CacheServiceState::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
CacheServer &cs = CacheServer::GetInstance();
// Start at a random queue in [0, numQ - 1]; the loop below then rotates through the queues
// (qID++ % numQ) when asking for free request tags.
int32_t numQ = cs.GetNumGrpcWorkers();
auto rng = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
int32_t qID = distribution(rng);
// Every request dispatched below is remembered here so it can be waited on and its tag returned.
std::vector<CacheServerRequest *> cache_rq_list;
auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
const auto num_elements = p->rows()->size();
auto connection_id = p->connection_id();
cache_rq_list.reserve(num_elements);
// The output starts with an array of (num_elements + 1) offsets. Entry 0 is the size of that
// array itself, i.e. where the first row's data begins.
// NOTE(review): nothing here checks that *out is large enough for the header and the rows;
// presumably the caller sized it accordingly -- confirm.
int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
offset_array[0] = data_offset;
// NOTE(review): i is deduced as int while num_elements is whatever unsigned type size() returns,
// so this comparison mixes signedness.
for (auto i = 0; i < num_elements; ++i) {
auto data_locator = p->rows()->Get(i);
auto node_id = data_locator->node_id();
size_t sz = data_locator->size();
void *source_addr = reinterpret_cast<void *>(data_locator->addr());
auto key = data_locator->key();
// Please read the comment in CacheServer::BatchFetchRows where we allocate
// the buffer big enough so each thread (which we are going to dispatch) will
// not run into false sharing problem. We are going to round up sz to 4k.
auto sz_4k = round_up_4K(sz);
// Entry i + 1 marks the end of row i's slot (and the start of row i + 1's slot).
offset_array[i + 1] = offset_array[i] + sz_4k;
// Rows with a zero size get no request; only their offset entry is written.
if (sz > 0) {
WritableSlice row_data(*out, offset_array[i], sz);
// Get a request and send to the proper worker (at some numa node) to do the fetch.
worker_id_t worker_id = cs.IsNumaAffinityOn() ? cs.GetWorkerByNumaId(node_id) : cs.GetRandomWorker();
CacheServerRequest *cache_rq;
// NOTE(review): assuming RETURN_IF_NOT_OK returns on error, a failure here (or in PushRequest
// below) leaves the tags already collected in cache_rq_list without a ReturnRequestTag call,
// and requests already pushed may still write into *out -- confirm this is acceptable.
RETURN_IF_NOT_OK(cs.GetFreeRequestTag(qID++ % numQ, &cache_rq));
cache_rq_list.push_back(cache_rq);
// Set up all the necessary fields.
cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
cache_rq->rq_.set_connection_id(connection_id);
cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
auto dest_addr = row_data.GetMutablePointer();
// Serialize the key, size and the source/destination addresses as a FetchRowMsg and attach it
// to the request as its payload.
flatbuffers::FlatBufferBuilder fb2;
FetchRowMsgBuilder bld(fb2);
bld.add_key(key);
bld.add_size(sz);
bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
auto offset = bld.Finish();
fb2.Finish(offset);
cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
RETURN_IF_NOT_OK(cs.PushRequest(worker_id, cache_rq));
}
}
// Now wait for all of them to come back. Let go of the shared lock. We shouldn't be holding
// any lock while we can wait for a long time.
rw.Unlock();
Status rc;
for (CacheServerRequest *rq : cache_rq_list) {
// NOTE(review): if Wait() itself fails, the remaining requests are neither waited on nor have
// their tags returned (assuming RETURN_IF_NOT_OK returns on error) -- confirm.
RETURN_IF_NOT_OK(rq->Wait());
// Keep only the first error, and skip errors that report an interrupt.
if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
rc = rq->rc_;
}
RETURN_IF_NOT_OK(cs.ReturnRequestTag(rq));
}
return rc;
}
Status CacheService::InternalFetchRow(const FetchRowMsg *p) { Status CacheService::InternalFetchRow(const FetchRowMsg *p) {
RETURN_UNEXPECTED_IF_NULL(p); RETURN_UNEXPECTED_IF_NULL(p);
SharedLock rw(&rw_lock_); SharedLock rw(&rw_lock_);

@ -75,13 +75,6 @@ class CacheService : public Service {
Status PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v, Status PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
const std::shared_ptr<flatbuffers::FlatBufferBuilder> &); const std::shared_ptr<flatbuffers::FlatBufferBuilder> &);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] fbb (first argument, unnamed in this declaration) A flatbuffer builder whose buffer holds
///   the BatchDataLocatorMsg describing the rows to fetch.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &, WritableSlice *out) const;
/// \brief Getter function /// \brief Getter function
/// \return Spilling path /// \return Spilling path
Path GetSpillPath() const; Path GetSpillPath() const;

@ -87,6 +87,7 @@ class WritableSlice : public ReadableSlice {
public: public:
friend class StorageContainer; friend class StorageContainer;
friend class CacheService; friend class CacheService;
friend class CacheServer;
/// \brief Default constructor /// \brief Default constructor
WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
/// \brief This form of a constructor takes a pointer and its size. /// \brief This form of a constructor takes a pointer and its size.

@ -42,7 +42,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheCApiSamplerNull) {
// Create an ImageFolder Dataset, this folder_path only has 2 images in it // Create an ImageFolder Dataset, this folder_path only has 2 images in it
std::string folder_path = datasets_root_path_ + "/testImageNetData/train/"; std::string folder_path = datasets_root_path_ + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, nullptr, {}, {}, some_cache); std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, nullptr, {}, {}, some_cache);
EXPECT_EQ(ds, nullptr); EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
// Now the parameter check for ImageFolderNode would fail and we would end up with a nullptr iter.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
} }
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheImageFolderCApi) { TEST_F(MindDataTestCacheOp, DISABLED_TestCacheImageFolderCApi) {

@ -121,7 +121,7 @@ HandleRcExit $? 0 1
# find a port that is occupied using netstat # find a port that is occupied using netstat
if [ -x "$(command -v netstat)" ]; then if [ -x "$(command -v netstat)" ]; then
port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1) port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1)
if [ ${port} -gt 1025 ]; then if [ ${port} -gt 1025 ]; then
# start cache server with occupied port # start cache server with occupied port
cmd="${CACHE_ADMIN} --start -p ${port}" cmd="${CACHE_ADMIN} --start -p ${port}"
@ -171,7 +171,12 @@ HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w illegal" cmd="${CACHE_ADMIN} --start -w illegal"
CacheAdminCmd "${cmd}" 1 CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0 HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 101" num_cpu=$(grep -c processor /proc/cpuinfo)
# Build a --start command with an oversized worker count (-w): 101 when the
# machine reports fewer than 100 processors, otherwise one more than the
# processor count.
if [ "${num_cpu}" -lt 100 ]; then
  cmd="${CACHE_ADMIN} --start -w 101"
else
  # Arithmetic expansion is required here: a plain "${num_cpu}+1" would pass
  # the literal text (e.g. "128+1") to cache_admin instead of the number 129.
  cmd="${CACHE_ADMIN} --start -w $((num_cpu + 1))"
fi
CacheAdminCmd "${cmd}" 1 CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0 HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 9999999" cmd="${CACHE_ADMIN} --start -w 9999999"

@ -1,10 +0,0 @@
# Helper for the "server stopped while a pipeline is running" test:
# start the cache server, run the pytest case in the background, stop the
# server while the test is presumably still running, then wait for the test.
cache_admin=~/cache/cache_admin

# Bring the server up and publish the session id (last field of the -g output)
# to the test through the environment.
"${cache_admin}" --start
generated_session=$("${cache_admin}" -g | awk '{print $NF}')
export SESSION_ID="${generated_session}"

# Launch the test in the background and remember its process id.
pytest dataset/test_cache_nomap.py::test_cache_nomap_server_stop &
test_pid=$!

# Give the test a moment before stopping the server underneath it.
sleep 2
"${cache_admin}" --stop
sleep 1

# The script's exit status is the exit status of the background test.
wait "${test_pid}"
Loading…
Cancel
Save