diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt
index ebb3b334fd..1dfd9d6dea 100644
--- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt
@@ -24,6 +24,11 @@ if (ENABLE_TDTQUE)
     add_definitions(-D ENABLE_TDTQUE)
     message(STATUS "TDT queue is enabled")
 endif ()
+if (MS_BUILD_GRPC)
+    set (ENABLE_CACHE true)
+    add_definitions(-D ENABLE_CACHE)
+    message(STATUS "Cache is enabled")
+endif()
 
 # conde coverage
 # option(ENABLE_COVERAGE "Enable code coverage report" OFF)
@@ -47,10 +52,6 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
 
-include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
-set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
-ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
-
 ################## Include sub-modules ###############################
 add_subdirectory(util)
 add_subdirectory(core)
@@ -70,8 +71,6 @@ add_dependencies(engine-datasetops-source-sampler core)
 add_dependencies(engine-datasetops core)
 add_dependencies(engine-datasetops-mapop core)
 add_dependencies(engine-opt core)
-add_dependencies(engine-cache-client core)
-add_dependencies(engine-cache-server core)
 add_dependencies(engine-perf core)
 add_dependencies(engine-gnn core)
 add_dependencies(engine core)
@@ -85,7 +84,11 @@ endif()
 if (ENABLE_TDTQUE)
     add_dependencies(engine-tdt core)
 endif ()
-
+if (ENABLE_CACHE)
+    add_dependencies(engine-datasetops engine-cache-client)
+    add_dependencies(engine-cache-client core)
+    add_dependencies(engine-cache-server core)
+endif ()
 ################### Create _c_dataengine Library ######################
 set(submodules
     $
@@ -105,7 +108,6 @@ set(submodules
     $
     $
     $
-    $<TARGET_OBJECTS:engine-cache-server>
     $
     $
     $
@@ -123,8 +125,6 @@
 else ()
     add_library(_c_dataengine SHARED ${submodules})
 endif ()
 
-add_dependencies(_c_dataengine generated_engine_files)
-
 if (ENABLE_PYTHON)
     set_target_properties(_c_dataengine PROPERTIES
         PREFIX "${PYTHON_MODULE_PREFIX}"
@@ -187,6 +187,6 @@ else()
     endif ()
 endif()
 
-if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
+if (MS_BUILD_GRPC)
     target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
-endif()
\ No newline at end of file
+endif()
diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc
index aa5ba9e561..018611332a 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc
@@ -22,7 +22,25 @@ namespace dataset {
 PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
                   (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
-                    .def(py::init<uint32_t, uint64_t, bool>());
+                    .def(
+                      py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) {
+                        std::shared_ptr<CacheClient> cc;
+                        CacheClient::Builder builder;
+                        builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize(
+                          prefetch_sz);
+                        THROW_IF_ERROR(builder.Build(&cc));
+                        return cc;
+                      }))
+                    .def("GetStat", [](CacheClient &cc) {
+                      CacheServiceStat stat{};
+                      THROW_IF_ERROR(cc.GetStat(&stat));
+                      return stat;
+                    });
+                  (void)py::class_<CacheServiceStat>(*m, "CacheServiceStat")
+                    .def(py::init<>())
.def_readwrite("avg_cache_sz", &CacheServiceStat::avg_cache_sz) + .def_readwrite("num_mem_cached", &CacheServiceStat::num_mem_cached) + .def_readwrite("num_disk_cached", &CacheServiceStat::num_disk_cached); })); } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h index be875c5028..d2ef2c14c9 100644 --- a/mindspore/ccsrc/minddata/dataset/core/constants.h +++ b/mindspore/ccsrc/minddata/dataset/core/constants.h @@ -72,7 +72,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10; // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) constexpr uint8_t kCVInvalidType = 255; -using connection_id_type = int64_t; +using connection_id_type = uint64_t; +using session_id_type = uint32_t; using row_id_type = int64_t; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt index 342eac8fb4..0b8d526f56 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt @@ -20,10 +20,8 @@ if (ENABLE_PYTHON) target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) endif() +add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop) + if (ENABLE_TDTQUE) - add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf - engine-cache-client engine-cache-server engine-datasetops-mapop) -else () - add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf - engine-cache-client engine-cache-server engine-datasetops-mapop) + add_dependencies(engine engine-tdt) endif () diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt index 5e7ebea176..5f962be8b7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt @@ -1,8 +1,47 @@ +include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") +set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") +ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU}) + file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) + add_library(engine-cache-client OBJECT cache_client.cc + cache_fbb.cc cache_request.cc) -add_library(engine-cache-server OBJECT - cache_service.cc - cache_server.cc) + +if (ENABLE_CACHE) + ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto) + target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc) + + add_library(engine-cache-server OBJECT + ${CACHE_GRPC_SRCS} + cache_grpc_server.cc + cache_arena.cc + cache_service.cc + cache_server.cc) + + add_executable(cache_server cache_main.cc) + target_link_libraries(cache_server + engine-cache-server + $ + mindspore + mindspore::glog + mindspore::protobuf + mindspore::grpc++ + mindspore_gvar + ${PYTHON_LIBRARIES} + ${SECUREC_LIBRARY} + pthread) + + add_executable(cache_admin cache_admin.cc cache_admin_arg.cc) + target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES} mindspore::glog) + + 
+    add_dependencies(engine-cache-server generated_engine_files)
+
+else ()
+    ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PROTO_HDRS cache_grpc.proto)
+    target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
+endif ()
+
+add_dependencies(engine-cache-client generated_engine_files)
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc
new file mode 100644
index 0000000000..6de1e61bbe
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#include <iostream>
+#include <sstream>
+#ifdef USE_GLOG
+#include <glog/logging.h>
+#endif
+#include "minddata/dataset/engine/cache/cache_admin_arg.h"
+
+namespace ds = mindspore::dataset;
+
+int main(int argc, char **argv) {
+  ds::Status rc;
+  ds::CacheAdminArgHandler args;
+  std::stringstream arg_stream;
+
+#ifdef USE_GLOG
+  FLAGS_log_dir = "/tmp";
+  google::InitGoogleLogging(argv[0]);
+#endif
+
+  std::string warningMsg;
+  warningMsg.reserve(512);
+  warningMsg += "WARNING:\n";
+  warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research";
+  warningMsg += " purposes at this time.\n";
+  warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n";
+
+  // A warning message until the code is mature enough.
+  std::cerr << warningMsg << std::endl;
+
+  if (argc == 1) {
+    args.Help();
+    return 0;
+  }
+
+  // ingest all the args into a string stream for parsing
+  for (int i = 1; i < argc; ++i) {
+    arg_stream << " " << std::string(argv[i]);
+  }
+
+  // Parse the args
+  rc = args.ParseArgStream(&arg_stream);
+  if (!rc.IsOk()) {
+    std::cerr << rc.ToString() << std::endl;
+    return 1;
+  }
+
+  // Execute the command
+  rc = args.RunCommand();
+  if (!rc.IsOk()) {
+    std::cerr << rc.ToString() << std::endl;
+    return 1;
+  }
+
+  return 0;
+}
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc
new file mode 100644
index 0000000000..892f7842ef
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc
@@ -0,0 +1,396 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#include "minddata/dataset/engine/cache/cache_admin_arg.h"
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+#include <cerrno>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <string>
+#include "minddata/dataset/engine/cache/cache_request.h"
+#include "minddata/dataset/engine/cache/cache_client.h"
+#include "minddata/dataset/util/path.h"
+
+namespace mindspore {
+namespace dataset {
+
+const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1";
+const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
+const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";
+
+CacheAdminArgHandler::CacheAdminArgHandler()
+    : port_(kDefaultPort),
+      session_id_(0),
+      num_workers_(kDefaultNumWorkers),
+      shm_mem_sz_(kDefaultSharedMemorySizeInGB),
+      log_level_(kDefaultLogLevel),
+      hostname_(kDefaultHost),
+      spill_dir_(kDefaultSpillDir),
+      command_id_(CommandId::kCmdUnknown) {
+  // Initialize the command mappings
+  arg_map_["-h"] = ArgValue::kArgHost;
+  arg_map_["--hostname"] = ArgValue::kArgHost;
+  arg_map_["-p"] = ArgValue::kArgPort;
+  arg_map_["--port"] = ArgValue::kArgPort;
+  arg_map_["--start"] = ArgValue::kArgStart;
+  arg_map_["--stop"] = ArgValue::kArgStop;
+  arg_map_["--help"] = ArgValue::kArgHelp;
+  arg_map_["--generate_session"] = ArgValue::kArgGenerateSession;
+  arg_map_["-g"] = ArgValue::kArgGenerateSession;
+  arg_map_["--destroy_session"] = ArgValue::kArgDestroySession;
+  arg_map_["-d"] = ArgValue::kArgDestroySession;
+  arg_map_["--spilldir"] = ArgValue::kArgSpillDir;
+  arg_map_["-s"] = ArgValue::kArgSpillDir;
+  arg_map_["-w"] = ArgValue::kArgNumWorkers;
+  arg_map_["--workers"] = ArgValue::kArgNumWorkers;
+  arg_map_["-m"] = ArgValue::kArgSharedMemorySize;
+  arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize;
+  arg_map_["-l"] = ArgValue::kArgLogLevel;
+  arg_map_["--minloglevel"] = ArgValue::kArgLogLevel;
+  // Initialize argument tracker with false values
+  for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) {
+    ArgValue currAV = static_cast<ArgValue>(i);
+    used_args_[currAV] = false;
+  }
+}
+
+Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
+                                       CommandId command_id) {
+  // Detect if the user tried to provide this argument more than once
+  ArgValue selected_arg = arg_map_[option];
+  if (used_args_[selected_arg]) {
+    std::string err_msg = "The " + option + " argument was given more than once.";
+    return Status(StatusCode::kSyntaxError, err_msg);
+  }
+
+  // Flag that this arg is used now
+  used_args_[selected_arg] = true;
+
+  // Some options are just arguments, for example "--port 50052" is not a command, it's just an argument.
+  // Other options are actual commands, for example "--destroy_session 1234". This executes the destroy session.
+  // If this option is also a command, make sure there has not been multiple commands given before assigning it.
+  if (command_id != CommandId::kCmdUnknown) {
+    if (command_id_ != CommandId::kCmdUnknown) {
+      std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
+      return Status(StatusCode::kSyntaxError, err_msg);
+    } else {
+      command_id_ = command_id;
+    }
+  }
+
+  std::string value_as_string;
+
+  // Fetch the argument from the arg stream into a string
+  *arg_stream >> value_as_string;
+  if (value_as_string.empty()) {
+    std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <value>";
Syntax: " + option + " "; + return Status(StatusCode::kSyntaxError, err_msg); + } + + // Now, attempt to convert the value into it's string format for output + try { + *out_arg = std::stoul(value_as_string); + } catch (const std::exception &e) { + std::string err_msg = "Invalid numeric value: " + value_as_string; + return Status(StatusCode::kSyntaxError, err_msg); + } + + return Status::OK(); +} + +Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream, + CommandId command_id) { + // Detect if the user tried to provide this argument more than once + ArgValue selected_arg = arg_map_[option]; + if (used_args_[selected_arg]) { + std::string err_msg = "The " + option + " argument was given more than once."; + return Status(StatusCode::kSyntaxError, err_msg); + } + + // Flag that this arg is used now + used_args_[selected_arg] = true; + + // Some options are just arguments, for example "--hostname "127.0.0.1" is not a command, it's just an argument. + // Other options are actual commands, for example "--start". + // If this option is also a command, make sure there has not been multiple commands given before assigning it. + if (command_id != CommandId::kCmdUnknown) { + if (command_id_ != CommandId::kCmdUnknown) { + std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option; + return Status(StatusCode::kSyntaxError, err_msg); + } else { + command_id_ = command_id; + } + } + + // If there is no argument to get, such as the --start command, then out_arg will be a nullptr. + if (out_arg != nullptr) { + // Fetch the argument from the arg stream into a string + *arg_stream >> *out_arg; + if (out_arg->empty()) { + std::string err_msg = option + " option requires an argument field. Syntax: " + option + " "; + return Status(StatusCode::kSyntaxError, err_msg); + } + } + + return Status::OK(); +} + +Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) { + std::string tok; + while (*arg_stream >> tok) { + switch (arg_map_[tok]) { + case ArgValue::kArgHost: { + RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream)); + break; + } + case ArgValue::kArgPort: { + RETURN_IF_NOT_OK(AssignArg(tok, &port_, arg_stream)); + break; + } + case ArgValue::kArgStart: { + RETURN_IF_NOT_OK(AssignArg(tok, static_cast(nullptr), arg_stream, CommandId::kCmdStart)); + break; + } + case ArgValue::kArgStop: { + RETURN_IF_NOT_OK(AssignArg(tok, static_cast(nullptr), arg_stream, CommandId::kCmdStop)); + break; + } + case ArgValue::kArgGenerateSession: { + RETURN_IF_NOT_OK( + AssignArg(tok, static_cast(nullptr), arg_stream, CommandId::kCmdGenerateSession)); + break; + } + case ArgValue::kArgHelp: { + command_id_ = CommandId::kCmdHelp; + break; + } + case ArgValue::kArgDestroySession: { + // session_id is an unsigned type. We may need to template the AssignArg function so that + // it can handle different flavours of integers instead of just int32_t. 
+        int32_t session_int;
+        RETURN_IF_NOT_OK(AssignArg(tok, &session_int, arg_stream, CommandId::kCmdDestroySession));
+        session_id_ = session_int;
+        break;
+      }
+      case ArgValue::kArgNumWorkers: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &num_workers_, arg_stream));
+        break;
+      }
+      case ArgValue::kArgSpillDir: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &spill_dir_, arg_stream));
+        break;
+      }
+      case ArgValue::kArgSharedMemorySize: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &shm_mem_sz_, arg_stream));
+        break;
+      }
+      case ArgValue::kArgLogLevel: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream));
+        break;
+      }
+      default: {
+        // Save space delimited trailing arguments
+        trailing_args_ += (" " + tok);
+        break;
+      }
+    }
+  }
+
+  RETURN_IF_NOT_OK(Validate());
+
+  return Status::OK();
+}
+
+Status CacheAdminArgHandler::Validate() {
+  // This sanity check is delayed until now in case there may be valid use-cases of trailing args.
+  // Any unhandled arguments at this point is an error.
+  if (!trailing_args_.empty()) {
+    std::string err_msg = "Invalid arguments provided: " + trailing_args_;
+    return Status(StatusCode::kSyntaxError, err_msg);
+  }
+
+  // The user must pick at least one command. i.e. it's meaningless to just give a hostname or port but no command to
+  // run.
+  if (command_id_ == CommandId::kCmdUnknown) {
+    std::string err_msg = "No command provided";
+    return Status(StatusCode::kSyntaxError, err_msg);
+  }
+
+  // Additional checks here
+  if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be a positive value.");
+  if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
+  // port range check?
+
+  return Status::OK();
+}
+
+Status CacheAdminArgHandler::RunCommand() {
+  switch (command_id_) {
+    case CommandId::kCmdHelp: {
+      Help();
+      break;
+    }
+    case CommandId::kCmdStart: {
+      RETURN_IF_NOT_OK(StartServer());
+      break;
+    }
+    case CommandId::kCmdStop: {
+      RETURN_IF_NOT_OK(StopServer());
+      break;
+    }
+    case CommandId::kCmdGenerateSession: {
+      CacheClientGreeter comm(hostname_, port_, 1);
+      RETURN_IF_NOT_OK(comm.ServiceStart());
+      auto rq = std::make_shared<GenerateSessionIdRequest>();
+      RETURN_IF_NOT_OK(comm.HandleRequest(rq));
+      RETURN_IF_NOT_OK(rq->Wait());
+      std::cout << rq->GetSessionId() << std::endl;
+      break;
+    }
+    case CommandId::kCmdDestroySession: {
+      CacheClientGreeter comm(hostname_, port_, 1);
+      RETURN_IF_NOT_OK(comm.ServiceStart());
+      CacheClientInfo cinfo;
+      cinfo.set_session_id(session_id_);
+      auto rq = std::make_shared<DropSessionRequest>(cinfo);
+      RETURN_IF_NOT_OK(comm.HandleRequest(rq));
+      RETURN_IF_NOT_OK(rq->Wait());
+      std::cout << "Drop session successful" << std::endl;
+      break;
+    }
+    default: {
+      RETURN_STATUS_UNEXPECTED("Invalid cache admin command id.");
+      break;
+    }
+  }
+
+  return Status::OK();
+}
+
+Status CacheAdminArgHandler::StartServer() {
+  // There currently does not exist any "install path" or method to identify which path the installed binaries will
+  // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the
+  // cache_admin binary (this process).
+  const std::string self_proc = "/proc/self/exe";
+  std::string canonical_path;
+  canonical_path.resize(400);  // PATH_MAX is large. This value should be big enough for our use.
+  // Some lower level OS library calls are needed here to determine the binary path.
+  // Fetch the path of this cache_admin binary into a C character array and then truncate off the binary name so
+  // that we are left with only the absolute path
+  if (realpath(self_proc.data(), canonical_path.data()) == nullptr) {
+    std::string err_msg = "Failed to identify cache admin binary path: " + std::to_string(errno);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  canonical_path.resize(strlen(canonical_path.data()));
+  auto last_separator = canonical_path.find_last_of('/');
+  CHECK_FAIL_RETURN_UNEXPECTED(last_separator != std::string::npos, "No / found");
+  // truncate the binary name so we are left with the absolute path of cache_admin binary
+  canonical_path.resize(last_separator + 1);
+  std::string cache_server_binary = canonical_path + std::string(kServerBinary);
+
+  // Create a pipe before we fork. If all goes well, the child will run as a daemon in the background
+  // and never returns until shutdown. If there is any error, the child will notify us through the pipe.
+  int fd[2];
+  if (pipe(fd) == -1) {
+    std::string err_msg = "Failed to create a pipe for communication " + std::to_string(errno);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // fork the child process to become the daemon
+  pid_t pid;
+  pid = fork();
+
+  // failed to fork
+  if (pid < 0) {
+    std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  } else if (pid > 0) {
+    // As a parent, we close the write end. We only listen.
+    close(fd[1]);
+    dup2(fd[0], 0);
+    close(fd[0]);
+    wait(nullptr);
+    std::string msg;
+    const int32_t buf_sz = 1024;
+    msg.resize(buf_sz);
+    auto n = read(0, msg.data(), buf_sz);
+    if (n < 0) {
+      std::string err_msg = "Failed to read from pipe " + std::to_string(errno);
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+    msg.resize(n);
+    std::cout << msg << std::endl;
+    return Status::OK();
+  } else {
+    // Child here ...
+    // Close all stdin, redirect stdout and stderr to the write end of the pipe.
+    close(fd[0]);
+    dup2(fd[1], 1);
+    dup2(fd[1], 2);
+    close(0);
+    close(fd[1]);
+    // exec the cache server binary in this process
+    std::string port_string = std::to_string(port_);
+    std::string workers_string = std::to_string(num_workers_);
+    std::string shared_memory_string = std::to_string(shm_mem_sz_);
+    std::string minloglevel_string = std::to_string(log_level_);
+    std::string daemonize_string = "true";
+
+    char *argv[8];
+    argv[0] = cache_server_binary.data();  // First arg is usually the binary name
+    argv[1] = spill_dir_.data();
+    argv[2] = workers_string.data();
+    argv[3] = port_string.data();
+    argv[4] = shared_memory_string.data();
+    argv[5] = minloglevel_string.data();
+    argv[6] = daemonize_string.data();
+    argv[7] = nullptr;
+
+    // Now exec the binary
+    execv(argv[0], argv);
+    // If the exec was successful, this line will never be reached due to process image being replaced.
+    // ..unless exec failed.
+ std::string err_msg = "Failed to exec cache server: " + cache_server_binary; + std::cerr << err_msg << std::endl; + RETURN_STATUS_UNEXPECTED(err_msg); + } +} + +Status CacheAdminArgHandler::StopServer() { + CacheClientGreeter comm(hostname_, port_, 1); + RETURN_IF_NOT_OK(comm.ServiceStart()); + auto rq = std::make_shared(); + RETURN_IF_NOT_OK(comm.HandleRequest(rq)); + return Status::OK(); +} + +void CacheAdminArgHandler::Help() { + std::cerr << "Syntax:\n"; + std::cerr << " cache_admin [--start | --stop]\n"; + std::cerr << " [ [-h | --hostname] ]\n"; + std::cerr << " [ [-p | --port] ]\n"; + std::cerr << " [ [-g | --generate_session] ]\n"; + std::cerr << " [ [-d | --destroy_session] ]\n"; + std::cerr << " [ [-w | --workers] ]\n"; + std::cerr << " [ [-s | --spilldir] ]\n"; + std::cerr << " [ [-m | --shared_memory_size] ]\n"; + std::cerr << " [ [-l | --minloglevel] ]\n"; + std::cerr << " [--help]" << std::endl; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h new file mode 100644 index 0000000000..e48916c482 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/cache/cache_client.h" + +namespace mindspore { +namespace dataset { + +class CacheAdminArgHandler { + public: + static constexpr int32_t kDefaultPort = 50052; + static constexpr int32_t kDefaultNumWorkers = 32; + static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; + static constexpr int32_t kDefaultLogLevel = 1; + static const char kDefaultHost[]; + static const char kServerBinary[]; + static const char kDefaultSpillDir[]; + + // These are the actual command types to execute + enum class CommandId : int16_t { + kCmdHelp = 0, + kCmdStart = 1, + kCmdStop = 2, + kCmdGenerateSession = 3, + kCmdDestroySession = 4, + kCmdUnknown = 32767 + }; + + CacheAdminArgHandler(); + + ~CacheAdminArgHandler() = default; + + Status ParseArgStream(std::stringstream *arg_stream); + + Status RunCommand(); + + void Help(); + + private: + // These are the supported argument string integer mappings + enum class ArgValue : int16_t { + kArgUnknown = 0, // Must be at position 0. 
+    kArgStart = 1,
+    kArgStop = 2,
+    kArgHost = 3,
+    kArgPort = 4,
+    kArgHelp = 5,
+    kArgGenerateSession = 6,
+    kArgDestroySession = 7,
+    kArgSpillDir = 8,
+    kArgNumWorkers = 9,
+    kArgSharedMemorySize = 10,
+    kArgLogLevel = 11,
+    kArgNumArgs = 12  // Must be the last position to provide a count
+  };
+
+  Status StartServer();
+
+  Status StopServer();
+
+  Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
+                   CommandId command_id = CommandId::kCmdUnknown);
+
+  Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
+                   CommandId command_id = CommandId::kCmdUnknown);
+
+  Status Validate();
+
+  CommandId command_id_;
+  int32_t port_;
+  int32_t num_workers_;
+  int32_t shm_mem_sz_;
+  int32_t log_level_;
+  session_id_type session_id_;
+  std::string hostname_;
+  std::string spill_dir_;
+  std::string trailing_args_;
+  std::map<std::string, ArgValue> arg_map_;
+  std::map<ArgValue, bool> used_args_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc
new file mode 100644
index 0000000000..e358ac3573
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/cache/cache_arena.h"
+#include "minddata/dataset/util/path.h"
+namespace mindspore {
+namespace dataset {
+CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
+    : Arena::Arena(val_in_GB * 1024), port_(port), shmid_(-1) {}
+
+CachedSharedMemoryArena::~CachedSharedMemoryArena() {
+#if CACHE_LOCAL_CLIENT
+  if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
+    shmdt(this->ptr_);
+  }
+  this->ptr_ = nullptr;
+  if (shmid_ != -1) {
+    shmctl(shmid_, IPC_RMID, nullptr);
+    // Also remove the path we use to generate ftok.
+    Path p(PortToUnixSocketPath(port_));
+    (void)p.Remove();
+  }
+#endif
+}
+
+Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
+                                            size_t val_in_GB) {
+  RETURN_UNEXPECTED_IF_NULL(out);
+#if CACHE_LOCAL_CLIENT
+  auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
+  if (ba == nullptr) {
+    return Status(StatusCode::kOutOfMemory);
+  }
+  // Transfer the ownership of this pointer. Any future error in processing will be
+  // handled by the destructor of *out.
+  (*out).reset(ba);
+  // Generate the shared memory key from the port.
+  int err;
+  auto shm_key = PortToFtok(port, &err);
+  if (shm_key == (key_t)-1) {
+    std::string errMsg = "Ftok failed with errno " + std::to_string(err);
+    RETURN_STATUS_UNEXPECTED(errMsg);
+  }
+  auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
+  ba->shmid_ = shmget(shm_key, ba->size_in_bytes_, IPC_CREAT | IPC_EXCL | access_mode);
+  if (ba->shmid_ != -1) {
+    ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
+    if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
+      RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
+    }
+  } else {
+    RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
+  }
+  uint64_t num_blks = ba->size_in_bytes_ / ARENA_BLK_SZ;
+  MS_LOG(DEBUG) << "Memory pool is " << num_blks << " blocks of size " << ARENA_BLK_SZ << ".";
+  ba->tr_.Insert(0, num_blks);
+#endif
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h
new file mode 100644
index 0000000000..3d0ba7cbbc
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
+
+#include <memory>
+#include <string>
+#include "minddata/dataset/util/arena.h"
+#include "minddata/dataset/engine/cache/cache_common.h"
+namespace mindspore {
+namespace dataset {
+/// This is a derived class of Arena but resides in shared memory
+class CachedSharedMemoryArena : public Arena {
+ public:
+  ~CachedSharedMemoryArena() override;
+  /// \brief Create an Arena in shared memory
+  /// \param[out] out Pointer to a unique_ptr that will hold the new arena
+  /// \param[in] port TCP/IP port, used to derive the shared memory key
+  /// \param[in] val_in_GB size of shared memory in gigabytes
+  /// \return Status object
+  static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
+
+  /// \brief This returns where we attach to the shared memory.
+  /// Some gRPC requests will ask for a shared memory block, and
+  /// we can't return the absolute address as this makes no sense
+  /// in the client. So instead we will return an address relative
+  /// to the base address of the shared memory where we attach to.
+  /// \return Base address of the shared memory.
+  const void *SharedMemoryBaseAddr() const { return this->ptr_; }
+
+ private:
+  int32_t port_;
+  int shmid_;
+  /// Private constructor. Not to be called directly.
+  CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
index 04746131bb..428e0f4eeb 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
@@ -17,29 +17,45 @@
 #include <iostream>
 #include "minddata/dataset/engine/cache/cache_client.h"
 #include "minddata/dataset/engine/cache/cache_request.h"
+#include "minddata/dataset/engine/cache/cache_service.h"
+#include "minddata/dataset/engine/cache/cache_fbb.h"
 #include "minddata/dataset/util/bit.h"
 
 namespace mindspore {
 namespace dataset {
 
 // Constructor
-CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
-  : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
+CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
+                         int32_t port, int32_t num_workers, int32_t prefetch_size)
+    : server_connection_id_(0),
+      cache_mem_sz_(cache_mem_sz),
+      spill_(spill),
+      local_bypass_(false),
+      hostname_(std::move(hostname)),
+      port_(port),
+      num_workers_(num_workers),
+      prefetch_size_(prefetch_size) {
+  cinfo_.set_session_id(session_id);
+  comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_);
+}
 
 // print method for display cache details
 void CacheClient::Print(std::ostream &out) const {
-  out << "  Session id: " << session_id_ << "\n  Cache crc: " << cache_crc_
-      << "\n  Server cache id: " << server_connection_id_ << "\n  Cache mem size: " << cache_mem_sz_
-      << "\n  Spilling: " << std::boolalpha << spill_;
+  out << "  Session id: " << session_id() << "\n  Cache crc: " << cinfo_.crc()
+      << "\n  Server cache id: " << server_connection_id_ << "\n  Cache mem size: " << getCacheMemSz()
+      << "\n  Spilling: " << std::boolalpha << isSpill() << "\n  Hostname: " << getHostname()
+      << "\n  Port: " << getPort() << "\n  Number of rpc workers: " << getNumWorkers()
+      << "\n  Prefetch size: " << getPrefetchSize() << "\n  Local client support: " << std::boolalpha
+      << SupportLocalClient();
 }
 
 Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
-  CacheRowRequest rq(server_connection_id_, cookie());
-  RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
+  auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
+  RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
   if (row_id_from_server != nullptr) {
-    *row_id_from_server = rq.GetRowIdAfterCache();
+    *row_id_from_server = rq->GetRowIdAfterCache();
   }
   return Status::OK();
 }
@@ -47,29 +63,19 @@ Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_serv
 Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
   std::unique_ptr<DataBuffer> db_ptr = std::move(in);
   auto num_rows = db_ptr->NumRows();
-  std::vector<TensorRow> all_rows;
+  // We will send the requests async first on all rows and do a final wait.
   if (num_rows > 0) {
-    all_rows.reserve(num_rows);
-    // Break down the DataBuffer into TensorRow. We will send the requests async
-    // and then do a final wait.
-    MemGuard<CacheRowRequest> rq_arr;
-    RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
-    CacheServer &cs = CacheServer::GetInstance();
+    auto arr = std::make_unique<std::shared_ptr<CacheRowRequest>[]>(num_rows);
     for (auto i = 0; i < num_rows; ++i) {
       TensorRow row;
-      auto rq = rq_arr[i];
       RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
-      RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
-      RETURN_IF_NOT_OK(cs.PushRequest(rq));
-      // We can't let row go out of scope. Otherwise it will free all the tensor memory.
-      // So park it in the vector. When this function go out of scope, its memory
-      // will be freed.
-      all_rows.push_back(std::move(row));
+      arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
+      RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row));
+      RETURN_IF_NOT_OK(PushRequest(arr[i]));
     }
-    // Now we wait for the requests to be done.
+    // Now we wait for them to come back
     for (auto i = 0; i < num_rows; ++i) {
-      auto rq = rq_arr[i];
-      RETURN_IF_NOT_OK(rq->Wait());
+      RETURN_IF_NOT_OK(arr[i]->Wait());
     }
   }
   return Status::OK();
 }
@@ -77,11 +83,21 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
 Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
   RETURN_UNEXPECTED_IF_NULL(out);
-  BatchFetchRequest rq(server_connection_id_, row_id);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
-  RETURN_IF_NOT_OK(rq.RestoreRows(out));
-  return Status::OK();
+  auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient());
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  int64_t mem_addr;
+  Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
+  // Free the memory by sending a request back to the server.
+  if (mem_addr != -1) {
+    auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
+    Status rc2 = PushRequest(mfree_req);
+    // But we won't wait for the result for the sake of performance.
+    if (rc.IsOk() && rc2.IsError()) {
+      rc = rc2;
+    }
+  }
+  return rc;
 }
 
 Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
@@ -108,40 +124,44 @@
   // to create a cache and some other tree is trying to use the same cache.
   // That is allowed, however the crc better match!
   if (server_connection_id_) {
-    if (cache_crc_ != tree_crc) {
+    if (cinfo_.crc() != tree_crc) {
       RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
     }
     // Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
     // skip the build phase.
     lck.Unlock();  // GetStat will grab the mutex again. So unlock it to prevent deadlock.
-    CacheClient::ServiceStat stat{};
+    CacheServiceStat stat{};
     RETURN_IF_NOT_OK(GetStat(&stat));
     if (stat.cache_service_state == static_cast<int8_t>(CacheService::State::kFetchPhase)) {
      return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
    }
  } else {
-    cache_crc_ = tree_crc;  // It's really a new cache we're creating so save our crc in the client
-    // Combine the session and crc. This will form our client cache identifier.
-    connection_id_type connection_identification = (static_cast<connection_id_type>(session_id_) << 32) | cache_crc_;
+    cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client
     // Now execute the cache create request using this identifier and other configs
-    BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
+    CreateCacheRequest::CreateCacheFlag createFlag = CreateCacheRequest::CreateCacheFlag::kNone;
     if (spill_) {
-      createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
+      createFlag |= CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
     }
     if (generate_id) {
-      createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
+      createFlag |= CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
     }
-    CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
-    RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-    Status rc = rq.Wait();
+    // Start the comm layer to receive reply
+    RETURN_IF_NOT_OK(comm_->ServiceStart());
+    // Initiate connection
+    auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag);
+    RETURN_IF_NOT_OK(PushRequest(rq));
+    Status rc = rq->Wait();
     if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
-      server_connection_id_ = rq.GetServerConnectionId();
+      std::string cookie;
+      rq->ParseResult(&server_connection_id_, &cookie);
       if (rc.IsOk()) {
         // The 1st guy creating the cache will get a cookie back.
         // But this object may be shared among pipelines and we don't want
         // overwrite it.
-        cookie_ = rq.cookie();
+        cookie_ = cookie;
       }
+      // Attach to shared memory for local client
+      RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
     }
     // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
     // CacheOp to bypass the build phase.
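For orientation, here is a minimal sketch (not part of the patch itself) of how a pipeline is expected to drive this new client path, using only the Builder and CreateCache APIs introduced above; the session id and crc values are placeholders:

    std::shared_ptr<CacheClient> cc;
    CacheClient::Builder builder;
    // Session id would come from `cache_admin -g`; cache memory size 0 means uncapped.
    builder.SetSessionId(1456).SetCacheMemSz(0).SetSpill(true);
    Status rc = builder.Build(&cc);  // runs SanityCheck() before constructing the client
    if (rc.IsOk()) {
      // kDuplicateKey is not an error here: another pipeline already created this
      // cache, so the caller bypasses the build phase.
      rc = cc->CreateCache(tree_crc, /*generate_id=*/true);
    }
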
@@ -152,57 +172,57 @@
 
 Status CacheClient::PurgeCache() {
   UniqueLock lck(&mux_);
-  PurgeCacheRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  return rq.Wait();
+  auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  return Status::OK();
 }
 
 Status CacheClient::DestroyCache() {
   UniqueLock lck(&mux_);
-  DestroyCacheRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  return rq.Wait();
+  auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  return Status::OK();
 }
 
-Status CacheClient::GetStat(ServiceStat *stat) {
+Status CacheClient::GetStat(CacheServiceStat *stat) {
   SharedLock lck(&mux_);
   RETURN_UNEXPECTED_IF_NULL(stat);
-  GetStatRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
-  stat->num_disk_cached = rq.GetNumDiskCached();
-  stat->num_mem_cached = rq.GetNumMemCached();
-  stat->min_row_id = rq.GetMinRowId();
-  stat->max_row_id = rq.GetMaxRowId();
-  stat->cache_service_state = rq.GetState();
+  auto rq = std::make_shared<GetStatRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  rq->GetStat(stat);
   return Status::OK();
 }
 
 Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
   SharedLock lck(&mux_);
-  CacheSchemaRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
+  auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(rq->SerializeCacheSchemaRequest(map));
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
   return Status::OK();
 }
 
 Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
   SharedLock lck(&mux_);
   RETURN_UNEXPECTED_IF_NULL(map);
-  FetchSchemaRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
-  *map = rq.GetColumnMap();
+  auto rq = std::make_shared<FetchSchemaRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  *map = rq->GetColumnMap();
   return Status::OK();
 }
 
 Status CacheClient::BuildPhaseDone() const {
   SharedLock lck(&mux_);
-  BuildPhaseDoneRequest rq(server_connection_id_, cookie());
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
+  auto rq = std::make_shared<BuildPhaseDoneRequest>(server_connection_id_, cookie());
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
   return Status::OK();
 }
+
+Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
index 963d2e7e89..99b2209366 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
@@ -23,9 +23,13 @@
 #include <unordered_map>
 #include <vector>
+#include "minddata/dataset/core/config_manager.h"
+#ifdef ENABLE_CACHE
+#include "minddata/dataset/engine/cache/cache_grpc_client.h"
+#else
+#include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"
+#endif
 #include "minddata/dataset/engine/data_buffer.h"
-#include "minddata/dataset/engine/cache/cache_server.h" -#include "minddata/dataset/engine/cache/de_tensor_generated.h" #include "minddata/dataset/util/lock.h" namespace mindspore { @@ -35,18 +39,120 @@ namespace dataset { /// rows, etc. class CacheClient { public: + friend class CacheMergeOp; + + /// \brief A builder to help creating a CacheClient object + class Builder { + public: + Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) { + std::shared_ptr cfg = GlobalContext::config_manager(); + hostname_ = "127.0.0.1"; + port_ = 50052; + num_workers_ = cfg->num_parallel_workers(); + prefetch_size_ = 20; // rows_per_buf is too small (1 by default). + } + + /// Setter function to set the session id + /// \param session_id + /// \return Builder object itself. + Builder &SetSessionId(session_id_type session_id) { + session_id_ = session_id; + return *this; + } + + /// Setter function to set the cache memory size + /// \param cache_mem_sz + /// \return Builder object itself + Builder &SetCacheMemSz(uint64_t cache_mem_sz) { + cache_mem_sz_ = cache_mem_sz; + return *this; + } + + /// Setter function to spill attribute + /// \param spill + /// Builder object itself + Builder &SetSpill(bool spill) { + spill_ = spill; + return *this; + } + + /// Setter function to set rpc hostname + /// \param host + /// \return Builder object itself + Builder &SetHostname(std::string host) { + hostname_ = std::move(host); + return *this; + } + + /// Setter function to set tcpip port + /// \param port + /// \return Builder object itself. + Builder &SetPort(int32_t port) { + port_ = port; + return *this; + } + + /// Setter function to set number of async rpc workers + /// \param num_workers + /// \return Builder object itself + Builder &SetNumWorkers(int32_t num_workers) { + num_workers_ = num_workers; + return *this; + } + + /// Setter function to set prefetch amount for fetching rows from cache server + /// \param prefetch_sz + /// \return Builder object itself + Builder &SetPrefetchSize(int32_t prefetch_sz) { + prefetch_size_ = prefetch_sz; + return *this; + } + + /// Getter functions + session_id_type getSessionId() const { return session_id_; } + uint64_t getCacheMemSz() const { return cache_mem_sz_; } + bool isSpill() const { return spill_; } + const std::string &getHostname() const { return hostname_; } + int32_t getPort() const { return port_; } + int32_t getNumWorkers() const { return num_workers_; } + int32_t getPrefetchSize() const { return prefetch_size_; } + + Status SanityCheck() { + CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); + CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. 
+      CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
+      CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
+      CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
+      return Status::OK();
+    }
+
+    Status Build(std::shared_ptr<CacheClient> *out) {
+      RETURN_UNEXPECTED_IF_NULL(out);
+      RETURN_IF_NOT_OK(SanityCheck());
+      *out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_,
+                                           prefetch_size_);
+      return Status::OK();
+    }
+
+   private:
+    session_id_type session_id_;
+    uint64_t cache_mem_sz_;
+    bool spill_;
+    std::string hostname_;
+    int32_t port_;
+    int32_t num_workers_;
+    int32_t prefetch_size_;
+  };
+
   /// \brief Constructor
   /// \param session_id A user assigned session id for the current pipeline
   /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
   /// \param spill Spill to disk if out of memory
-  CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
+  CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
+              int32_t num_workers, int32_t prefetch_size);
 
   /// \brief Destructor
-  ~CacheClient() = default;
-
-  /// \brief Getter function for returning the current session id
-  /// \return session id
-  uint64_t session_id() const { return session_id_; }
+  ~CacheClient() { (void)comm_->ServiceStop(); }
 
   /// \brief Send a TensorRow to the cache server
   /// \param[in] row
@@ -83,14 +189,7 @@
   /// \brief Get the statistics from a cache.
   /// \param[in/out] Pointer to a pre-allocated ServiceStat object
   /// \return Status object
-  struct ServiceStat {
-    int64_t num_mem_cached;
-    int64_t num_disk_cached;
-    row_id_type min_row_id;
-    row_id_type max_row_id;
-    int8_t cache_service_state;
-  };
-  Status GetStat(ServiceStat *);
+  Status GetStat(CacheServiceStat *);
 
   /// \brief Cache the schema at the cache server
   /// \param map The unordered map of the schema
@@ -122,18 +221,45 @@
   /// \return Cookie
   std::string cookie() const { return cookie_; }
 
+  /// \brief Send a request async to the server
+  /// \param rq BaseRequest
+  /// \return Status object
+  Status PushRequest(std::shared_ptr<BaseRequest> rq) const;
+
+  /// \brief If the remote server supports local bypass using shared memory
+  /// \return boolean value
+  bool SupportLocalClient() const { return local_bypass_; }
+
+  /// \brief Return the base memory address if we attach to any shared memory.
+  auto SharedMemoryBaseAddr() const { return comm_->SharedMemoryBaseAddr(); }
+
+  /// Getter functions
+  session_id_type session_id() const { return cinfo_.session_id(); }
+  uint64_t getCacheMemSz() const { return cache_mem_sz_; }
+  bool isSpill() const { return spill_; }
+  const std::string &getHostname() const { return hostname_; }
+  int32_t getPort() const { return port_; }
+  int32_t getNumWorkers() const { return num_workers_; }
+  int32_t getPrefetchSize() const { return prefetch_size_; }
+
  private:
   mutable RWLock mux_;
   uint64_t cache_mem_sz_;
   bool spill_;
   // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
   // sharing of the cache.
-  uint32_t session_id_;
-  uint32_t cache_crc_;
+  CacheClientInfo cinfo_;
   // The server_connection_id_ is the actual id we use for operations after the cache is built
   connection_id_type server_connection_id_;
   // Some magic cookie returned from the cache server.
   std::string cookie_;
+  // Comm layer
+  bool local_bypass_;
+  std::string hostname_;
+  int32_t port_;
+  int32_t num_workers_;
+  int32_t prefetch_size_;
+  mutable std::shared_ptr<CacheClientGreeter> comm_;
 };
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h
new file mode 100644
index 0000000000..cf1960075e
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h
@@ -0,0 +1,90 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
+
+/// \note This header file contains common header files and some inlines used by
+/// both client and server side codes. Do not put code that is not common here.
+/// There are client and server specific header files.
+
+// On platform like Windows, we may support only tcp/ip clients
+#if !defined(_WIN32) && !defined(_WIN64)
+#define CACHE_LOCAL_CLIENT 1
+#endif
+
+#ifdef CACHE_LOCAL_CLIENT
+#include <sys/types.h>
+#include <sys/ipc.h>
+#include <sys/shm.h>
+#else
+typedef int key_t;
+#endif
+#ifdef ENABLE_CACHE
+#include <grpcpp/grpcpp.h>
+#endif
+#include <string>
+#ifdef ENABLE_CACHE
+#include "proto/cache_grpc.grpc.pb.h"
+#endif
+#include "proto/cache_grpc.pb.h"
+#include "minddata/dataset/engine/cache/cache_request.h"
+#include "minddata/dataset/engine/cache/de_tensor_generated.h"
+namespace mindspore {
+namespace dataset {
+/// \brief CacheRow and BatchFetch requests will switch to use shared memory method (if supported
+/// on the platform) when the amount of bytes sent is greater than the following number.
+/// For too small amount, we won't get any benefit using shared memory method because we need
+/// two rpc requests to use shared memory method.
+constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
+/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
+constexpr static uint32_t kLocalClientSupport = 1;
+/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
+/// inline in the protobuf. This also implies kLocalClientSupport is also true.
+constexpr static uint32_t kDataIsInSharedMemory = 2;
+
+/// \brief Convert a Status object into a protobuf
+/// \param rc[in] Status object
+/// \param reply[in/out] pointer to pre-allocated protobuf object
+inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
+  reply->set_rc(static_cast<int32_t>(rc.get_code()));
+  reply->set_msg(rc.ToString());
+}
+
+/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
+/// \param port
+/// \return unix socket url
+inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
+
+/// \brief Generate a shared memory key using the tcp/ip port.
+/// \note It must be called after the cache server generates the unix socket or ftok will fail.
+/// \note Caller must check the return value. -1 means ftok failed.
+/// \param[in] port
+/// \param[out] err If not null and ftok fails, this will contain the value of errno
+/// \return key
+inline key_t PortToFtok(int port, int *err) {
+  key_t shmkey = -1;
+#ifdef CACHE_LOCAL_CLIENT
+  const std::string unix_path = PortToUnixSocketPath(port);
+  shmkey = ftok(unix_path.data(), 'a');
+  if (err != nullptr && shmkey == (key_t)-1) {
+    *err = errno;
+  }
+#endif
+  return shmkey;
+}
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.cc
new file mode 100644
index 0000000000..7a49dfc237
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.cc
@@ -0,0 +1,151 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#include "minddata/dataset/engine/cache/cache_fbb.h"
+namespace mindspore {
+namespace dataset {
+/// A private function used by SerializeTensorRowHeader to serialize each column in a tensor
+/// \note Not to be called by outside world
+/// \return Status object
+Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
+                              const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
+  RETURN_UNEXPECTED_IF_NULL(out_off);
+  const Tensor *ts = ts_ptr.get();
+  auto shape_off = fbb->CreateVector(ts->shape().AsVector());
+  const auto ptr = ts->GetBuffer();
+  if (ptr == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
+  }
+  auto src = ts->type().value();
+  TensorType dest;
+#define CASE(t)                        \
+  case DataType::t:                    \
+    dest = TensorType::TensorType_##t; \
+    break
+  // Map the type to fill in the flat buffer.
+  switch (src) {
+    CASE(DE_BOOL);
+    CASE(DE_INT8);
+    CASE(DE_UINT8);
+    CASE(DE_INT16);
+    CASE(DE_UINT16);
+    CASE(DE_INT32);
+    CASE(DE_UINT32);
+    CASE(DE_INT64);
+    CASE(DE_UINT64);
+    CASE(DE_FLOAT16);
+    CASE(DE_FLOAT32);
+    CASE(DE_FLOAT64);
+    CASE(DE_STRING);
+    default:
+      MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
+      RETURN_STATUS_UNEXPECTED("Unknown type");
+  }
+#undef CASE
+
+  TensorMetaMsgBuilder ts_builder(*fbb);
+  ts_builder.add_dims(shape_off);
+  ts_builder.add_type(dest);
+  auto ts_off = ts_builder.Finish();
+  *out_off = ts_off;
+  return Status::OK();
+}
+
+Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
+  RETURN_UNEXPECTED_IF_NULL(out_fbb);
+  auto fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
+  try {
+    fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
+    std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
+    std::vector<int64_t> tensor_sz;
+    v.reserve(row.size());
+    tensor_sz.reserve(row.size());
+    // We will go through each column in the row.
+ for (const std::shared_ptr &ts_ptr : row) { + flatbuffers::Offset ts_off; + RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off)); + v.push_back(ts_off); + tensor_sz.push_back(ts_ptr->SizeInBytes()); + } + auto column_off = fbb->CreateVector(v); + auto data_sz_off = fbb->CreateVector(tensor_sz); + TensorRowHeaderMsgBuilder row_builder(*fbb); + row_builder.add_column(column_off); + row_builder.add_data_sz(data_sz_off); + // Pass the row_id even if it may not be known. + row_builder.add_row_id(row.getId()); + row_builder.add_size_of_this(-1); // fill in later after we call Finish. + auto out = row_builder.Finish(); + fbb->Finish(out); + // Now go back to fill in size_of_this in the flat buffer. + auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer()); + auto success = msg->mutate_size_of_this(fbb->GetSize()); + if (!success) { + RETURN_STATUS_UNEXPECTED("Unable to set size_of_this"); + } + (*out_fbb) = std::move(fbb); + return Status::OK(); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } +} + +Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr *out) { + RETURN_UNEXPECTED_IF_NULL(col_ts); + auto shape_in = col_ts->dims(); + auto type_in = col_ts->type(); + std::vector v; + v.reserve(shape_in->size()); + v.assign(shape_in->begin(), shape_in->end()); + TensorShape shape(v); + DataType::Type dest = DataType::DE_UNKNOWN; +#define CASE(t) \ + case TensorType_##t: \ + dest = DataType::Type::t; \ + break + + switch (type_in) { + CASE(DE_BOOL); + CASE(DE_INT8); + CASE(DE_UINT8); + CASE(DE_INT16); + CASE(DE_UINT16); + CASE(DE_INT32); + CASE(DE_UINT32); + CASE(DE_INT64); + CASE(DE_UINT64); + CASE(DE_FLOAT16); + CASE(DE_FLOAT32); + CASE(DE_FLOAT64); + CASE(DE_STRING); + } +#undef CASE + + DataType type(dest); + std::shared_ptr ts; + RETURN_IF_NOT_OK( + Tensor::CreateFromMemory(shape, type, static_cast(data.GetPointer()), data.GetSize(), &ts)); + // Next we restore the real data which can be embedded or stored separately. + if (ts->SizeInBytes() != data.GetSize()) { + MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" + << "Dumping tensor\n" + << *ts << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); + } + *out = std::move(ts); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.h new file mode 100644 index 0000000000..36a3d63099 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_ + +/// This header contains some serialize and deserialize functions for a tensor row using +/// Google Flatbuffers + +#include +#include +#include +#include "minddata/dataset/engine/cache/de_tensor_generated.h" +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/util/slice.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +/// \brief Function to serialize TensorRow header used by CacheRowRequest +/// \param row TensorRow +/// \param fbb [in/out] fbb that contains the serialized data +/// \return Status object +Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr *fbb); + +/// \brief A function used by BatchFetchRequest to deserialize a flat buffer back to a tensor row. +/// \param col_ts A serialized version of Tensor meta data +/// \param data Tensor data wrapped in a slice +/// \param out Tensor +/// \return Status object +Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr *out); +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto new file mode 100644 index 0000000000..68619d33ab --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +package mindspore.dataset; +option cc_enable_arenas = true; + +// The session_id and crc work together to uniquely identify this particular cache and allow +// sharing of the cache. +message CacheClientInfo { + uint32 session_id = 1; + uint32 crc = 2; +} + +message CacheRequest { + // Type of rpc request + int32 type = 1; + // Extra optional flag used by individual request if needed + uint32 flag = 2; + oneof connect_info { + // The server_connection_id is the actual id we use for operations after the cache is built + int64 connection_id = 3; + // But for some requests, like CreateCache, we have to use the session id and crc to connect to the server.
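The oneof completed just below pairs an established connection_id with the (session_id, crc) route used by CreateCache. A hedged sketch of how client code might fill either branch, assuming protobuf's standard protoc-generated accessors for the messages above (compiles against the generated header, not standalone):

#include "proto/cache_grpc.pb.h"  // generated from cache_grpc.proto

namespace ds = mindspore::dataset;

// Illustrative only: fill the request either with an established connection id
// or, for CreateCache, with the (session_id, crc) pair.
void FillConnectInfo(ds::CacheRequest *rq, bool established,
                     int64_t connection_id, uint32_t session_id, uint32_t crc) {
  if (established) {
    rq->set_connection_id(connection_id);        // selects the connection_id branch
  } else {
    auto *info = rq->mutable_connection_info();  // selects the connection_info branch
    info->set_session_id(session_id);
    info->set_crc(crc);
  }
  // Setting one branch of a oneof automatically clears the other; the server
  // can test rq->has_connection_info() to see which branch was taken.
}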
+ CacheClientInfo connection_info = 4; + } + // Everything else is just a vector of buffers + repeated bytes buf_data = 5; +} + +message CacheReply { + int32 rc = 1; + string msg = 2; + // Extra optional flag used by individual request if needed + uint32 flag = 3; + // What the server sends back is a plain buffer + bytes result = 4; +} + +service CacheServerGreeter { + rpc CacheServerRequest (CacheRequest) returns (CacheReply) {} +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc new file mode 100644 index 0000000000..33151201ea --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/cache/cache_grpc_client.h" +#include +namespace mindspore { +namespace dataset { +Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq, + std::unique_ptr &&tag) { + // If there is anything extra we need to do before we send, do it now. + RETURN_IF_NOT_OK(tag->base_rq_->Prepare()); + // One minute timeout + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); + tag->ctx_.set_deadline(deadline); + tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq); + tag->rpc_->StartCall(); + // The last step is to release ownership and transfer it to the completion queue. + // The memory will be released by WorkerEntry or by the destructor when we drain the queue + auto ccReqTag = tag.release(); + ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, + ccReqTag); // inject this object into the completion queue + return Status::OK(); +} + +CacheClientGreeter::~CacheClientGreeter() { + (void)ServiceStop(); + // Detach from shared memory if any + if (shmat_addr_ != nullptr) { + shmdt(shmat_addr_); + shmat_addr_ = nullptr; + } +} + +CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) + : num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) { + grpc::ChannelArguments args; + // We need to bump up the message size to unlimited. The default receiving + // message limit is 4MB which is not big enough.
+ args.SetMaxReceiveMessageSize(-1); +#if CACHE_LOCAL_CLIENT + // Try connecting locally to the unix_socket as the first preference + // Need to resolve hostname to ip address rather than do a string compare + if (hostname == "127.0.0.1") { + std::string target = "unix://" + PortToUnixSocketPath(port); + channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args); + } else { +#endif + std::string target = hostname + ":" + std::to_string(port); + channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args); +#if CACHE_LOCAL_CLIENT + } +#endif + stub_ = CacheServerGreeter::NewStub(channel_); +} + +Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) { + *local_bypass = false; +#if CACHE_LOCAL_CLIENT + int err; + shm_key_ = PortToFtok(port, &err); + if (shm_key_ == (key_t)-1) { + std::string errMsg = "Ftok failed with errno " + std::to_string(err); + RETURN_STATUS_UNEXPECTED(errMsg); + } + // Attach to the shared memory + shm_id_ = shmget(shm_key_, 0, 0); + if (shm_id_ == -1) { + RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno)); + } + shmat_addr_ = shmat(shm_id_, nullptr, 0); + if (shmat_addr_ == reinterpret_cast(-1)) { + RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); + } + *local_bypass = true; +#endif + return Status::OK(); +} + +Status CacheClientGreeter::DoServiceStart() { + RETURN_IF_NOT_OK(vg_.ServiceStart()); + RETURN_IF_NOT_OK(DispatchWorkers(num_workers_)); + return Status::OK(); +} + +Status CacheClientGreeter::DoServiceStop() { + // Shutdown the queue. We don't accept any new incomers. + cq_.Shutdown(); + // Shutdown the TaskGroup. + vg_.interrupt_all(); + vg_.join_all(Task::WaitFlag::kNonBlocking); + // Drain the queue + bool success; + void *tag; + while (cq_.Next(&tag, &success)) { + auto r = reinterpret_cast(tag); + delete r; + } + return Status::OK(); +} + +Status CacheClientGreeter::HandleRequest(std::shared_ptr rq) { + auto tag = std::make_unique(std::move(rq)); + return tag->MakeCall(stub_.get(), &cq_, std::move(tag)); +} + +Status CacheClientGreeter::WorkerEntry() { + TaskManager::FindMe()->Post(); + do { + bool success; + void *tag; + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1); + // Set a timeout of one second. Check for interrupt if we need to do an early exit. + auto r = cq_.AsyncNext(&tag, &success, deadline); + if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) { + auto rq = reinterpret_cast(tag); + if (success) { + auto &rc = rq->rc_; + if (!rc.ok()) { + auto error_code = rq->rc_.error_code(); + std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code); + Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + Status2CacheReply(remote_rc, &rq->base_rq_->reply_); + } + // Notify the waiting thread. + rq->Notify(); + } + // We can now free the memory + delete rq; + } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) { + // If we are interrupted, exit. Otherwise wait again. + RETURN_IF_INTERRUPTED(); + } else { + // Queue is drained.
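As an aside on AttachToSharedMemory above: it reduces to the classic SysV IPC sequence. A minimal standalone sketch of that sequence (POSIX only; the path convention mirrors PortToUnixSocketPath, and the port value is hypothetical):

#include <sys/ipc.h>
#include <sys/shm.h>
#include <cstdio>
#include <string>

int main() {
  const int port = 50052;  // hypothetical cache server port
  const std::string unix_path = "/tmp/cache_server_p" + std::to_string(port);
  key_t key = ftok(unix_path.data(), 'a');  // requires the socket file to already exist
  if (key == (key_t)-1) {
    std::perror("ftok");
    return 1;
  }
  int shm_id = shmget(key, 0, 0);  // look up an existing segment, don't create one
  if (shm_id == -1) {
    std::perror("shmget");
    return 1;
  }
  void *addr = shmat(shm_id, nullptr, 0);
  if (addr == reinterpret_cast<void *>(-1)) {
    std::perror("shmat");
    return 1;
  }
  // ... use the segment; client/server exchange offsets relative to addr ...
  shmdt(addr);
  return 0;
}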
+ break; + } + } while (true); + return Status::OK(); +} + +Status CacheClientGreeter::DispatchWorkers(int32_t num_workers) { + auto f = std::bind(&CacheClientGreeter::WorkerEntry, this); + for (auto i = 0; i < num_workers; ++i) { + RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async reply", f)); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h new file mode 100644 index 0000000000..8fbd265bc3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ + +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_common.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/task_manager.h" +namespace mindspore { +namespace dataset { +/// \brief A client view of a gRPC request +/// Like the class CacheServerRequest, this is used as a tag to inject into the gRPC +/// completion queue. The thread that makes the rpc request will wait on a wait post +/// area for the reply to come back. Since this tag will be deleted from memory, +/// we thus need to work on a shared pointer of the BaseRequest such that its +/// use count is at least two. Otherwise either thread will be referencing stale memory.
+/// \see CacheServerRequest +class CacheClientRequestTag { + public: + friend class CacheClientGreeter; + explicit CacheClientRequestTag(std::shared_ptr rq) : base_rq_(std::move(rq)) {} + ~CacheClientRequestTag() = default; + + /// \brief Make a RPC call + /// \param stub from CacheClientGreeter + /// \param cq from CacheClientGreeter + /// \return Status object + static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq, + std::unique_ptr &&tag); + + /// \brief Notify the client that a result has come back from the server + void Notify() { base_rq_->wp_.Set(); } + + private: + std::shared_ptr base_rq_; + grpc::Status rc_; + grpc::ClientContext ctx_; + std::unique_ptr> rpc_; +}; + +/// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC +/// \see BaseRequest +class CacheClientGreeter : public Service { + friend class CacheClient; + + public: + explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers); + ~CacheClientGreeter(); + + /// Override base Service class + Status DoServiceStart() override; + Status DoServiceStop() override; + + /// \brief Send the request to the server + /// \return Status object + Status HandleRequest(std::shared_ptr rq); + + /// \brief A handful of threads will be handling async reply from the server + /// \return + Status WorkerEntry(); + + /// \brief Kick off threads to receive reply from the server + Status DispatchWorkers(int32_t num_workers); + + /// \brief Attach to shared memory for local client + /// \note Called after we have established a connection. + /// \return Status object. + Status AttachToSharedMemory(int32_t port, bool *local_bypass); + + /// \brief This returns where we attach to the shared memory. + /// \return Base address of the shared memory. + const void *SharedMemoryBaseAddr() const { return shmat_addr_; } + + private: + std::shared_ptr channel_; + std::unique_ptr stub_; + grpc::CompletionQueue cq_; + TaskGroup vg_; + int32_t num_workers_; + key_t shm_key_; + int32_t shm_id_; + void *shmat_addr_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc new file mode 100644 index 0000000000..43c0bcfe5a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc @@ -0,0 +1,203 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +#include "minddata/dataset/engine/cache/cache_grpc_server.h" +#include +#include "minddata/dataset/engine/cache/cache_server.h" +#include "minddata/dataset/util/path.h" +#include "utils/log_adapter.h" +namespace mindspore { +namespace dataset { +CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb) + : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) { + // Setup a path for unix socket. + unix_socket_ = PortToUnixSocketPath(port); + // We can't generate the ftok key until the unix_socket_ is created +} + +void CacheServerGreeterImpl::Shutdown() { + if (server_) { + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1); + server_->Shutdown(deadline); + } + // Always shutdown the completion queue after the server. + if (cq_) { + cq_->Shutdown(); + // We need to drain the queue. All the tags come from + // the Services pool, which will be shut down as well, so we + // ignore the tags. + void *tag; + bool success; + while (cq_->Next(&tag, &success)) { + } + } +} + +CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); } + +Status CacheServerGreeterImpl::IpcResourceCleanup() { +#if CACHE_LOCAL_CLIENT + int err; + auto shm_key = PortToFtok(port_, &err); + // We are expecting that the unix path doesn't exist. + if (shm_key == (key_t)-1) { + return Status::OK(); + } + // Attach to the shared memory + auto shm_id = shmget(shm_key, 0, 0); + if (shm_id == -1) { + return Status::OK(); + } + struct shmid_ds ds {}; + auto inx = shmctl(shm_id, IPC_STAT, &ds); + if (inx == -1) { + std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id); + errMsg += "\nPlease remove it manually using ipcrm -m command"; + RETURN_STATUS_UNEXPECTED(errMsg); + } + if (ds.shm_nattch == 0) { + // Stale shared memory from last time. + // Remove both the memory and the socket path + inx = shmctl(shm_id, IPC_RMID, nullptr); + if (inx == -1) { + std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id); + errMsg += ". Errno: " + std::to_string(errno); + errMsg += "\nPlease remove it manually using ipcrm -m command"; + RETURN_STATUS_UNEXPECTED(errMsg); + } + Path p(unix_socket_); + (void)p.Remove(); + } else { + // Server is already up. + MS_LOG(ERROR) << "Cache server is already up and running"; + // We return a duplicate error. The main() will intercept + // and output a proper message + return Status(StatusCode::kDuplicateKey); + } +#endif + return Status::OK(); +} + +Status CacheServerGreeterImpl::Run() { + // To listen on all interfaces, use 0.0.0.0 + // Use 127.0.0.1 for just local access on the same machine. + std::string host("0.0.0.0"); // listen on all interfaces. + std::string server_address = host + ":" + std::to_string(port_); + grpc::ServerBuilder builder; + // Default message size for gRPC is 4MB.
Increase it to 2GB - 1 + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); + int port_tcpip = 0; +#if CACHE_LOCAL_CLIENT + int port_local = 0; + // Check if we need to clean up the shared memory in case the server + // came down unexpectedly, e.g. from a SEGV + RETURN_IF_NOT_OK(IpcResourceCleanup()); + // We also optimize for local clients on the same machine using unix socket + builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local); +#endif + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip); + builder.RegisterService(&svc_); + cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + if (server_) { + MS_LOG(INFO) << "Server listening on " << server_address; +#if CACHE_LOCAL_CLIENT + RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_)); + MS_LOG(INFO) << "Creation of local socket and shared memory successful"; +#endif + } else { + std::string errMsg = "Failed to start server. "; + if (port_tcpip != port_) { + errMsg += "Unable to bind to tcpip port " + std::to_string(port_) + "."; + } +#if CACHE_LOCAL_CLIENT + if (port_local == 0) { + errMsg += " Unable to create unix socket " + unix_socket_ + "."; + } +#endif + RETURN_STATUS_UNEXPECTED(errMsg); + } + return Status::OK(); +} + +Status CacheServerGreeterImpl::HandleRequest(int32_t worker_id) { + bool success; + void *tag; + // We loop through the grpc queue. Each successful connection + // will come back with our own tag, which is an instance of CacheServerRequest, + // and we simply call its functor. But first we need to create these instances + // and inject them into the grpc queue. + CacheServerRequest *p; + // Get a free tag from my free list. + RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(worker_id, &p)); + RETURN_IF_NOT_OK((*p)(&svc_, cq_.get())); + do { + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1); + // Set a timeout of one second. Check for interrupt if we need to do an early exit. + auto r = cq_->AsyncNext(&tag, &success, deadline); + if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) { + if (success) { + auto rq = static_cast(tag); + RETURN_IF_NOT_OK((*rq)(&svc_, cq_.get())); + } + } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) { + // If we are interrupted, exit. Otherwise wait again. + RETURN_IF_INTERRUPTED(); + } else { + // Queue is drained. + break; + } + } while (true); + return Status::OK(); +} + +Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq) { + auto myQID = getQid(); + if (st_ == STATE::CREATE) { + st_ = STATE::PROCESS; + svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this); + } else if (st_ == STATE::PROCESS) { + // Get a new tag and handle the next request before we serve the current request. + // The tag will be recycled when its state is changed to FINISH + CacheServerRequest *next_rq; + RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq)); + RETURN_IF_NOT_OK((*next_rq)(svc, cq)); + // Now we continue with the current request. + // The first thing we need to do is extract the type from the incoming request. + // When this object was first created (i.e. STATE::CREATE), we set the type to UNKNOWN. + type_ = static_cast(rq_.type()); + // Now we pass the address of this instance to CacheServer's main loop.
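The operator() above implements the usual gRPC async-server tag cycle: CREATE arms the service call, PROCESS arms a replacement tag before handing the current one off, and FINISH recycles it. A compressed, library-free sketch of just that control flow, with toy types standing in for the real service:

#include <queue>
#include <cstdio>

enum class STATE { CREATE, PROCESS, FINISH };

struct Tag {
  STATE st = STATE::CREATE;
};

// Toy event loop standing in for the completion queue. Each time a tag pops,
// we advance its state; on PROCESS we immediately enqueue a replacement tag
// so the server keeps accepting requests while this one is handled.
int main() {
  std::queue<Tag *> cq;
  cq.push(new Tag{});  // seed one outstanding request, like GetFreeRequestTag
  int served = 0;
  while (!cq.empty() && served < 3) {
    Tag *t = cq.front();
    cq.pop();
    if (t->st == STATE::CREATE) {
      t->st = STATE::PROCESS;
      cq.push(t);          // real code: RequestCacheServerRequest(...)
    } else if (t->st == STATE::PROCESS) {
      cq.push(new Tag{});  // replacement tag, like GetFreeRequestTag + (*next_rq)(...)
      t->st = STATE::FINISH;
      cq.push(t);          // real code: PushRequest to a worker, then Finish on the responder
      ++served;
    } else {
      std::printf("tag finished\n");
      delete t;            // real code: ReturnRequestTag recycles it
    }
  }
  while (!cq.empty()) { delete cq.front(); cq.pop(); }
  return 0;
}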
+ MS_LOG(DEBUG) << "Handle request " << *this; + auto &cs = CacheServer::GetInstance(); + RETURN_IF_NOT_OK(cs.PushRequest(myQID, this)); + } else if (st_ == STATE::FINISH) { + MS_LOG(DEBUG) << *this << " Finished."; + // Return back to the free list. + RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(this)); + } + return Status::OK(); +} + +void CacheServerRequest::Print(std::ostream &out) const { + if (rq_.has_connection_info()) { + out << "Session Id: " << rq_.connection_info().session_id() << " CRC: " << rq_.connection_info().crc(); + } else { + out << "Connection Id: " << rq_.connection_id(); + } + out << " "; + BaseRequest::Print(out); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h new file mode 100644 index 0000000000..ac3e648bf3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ + +#include +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_common.h" +#include "minddata/dataset/engine/cache/cache_arena.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +/// \brief Server side view of BaseRequest. Incoming request are in the form of protobuf objects +/// and this class is used to translate from protobuf to structures understood by CacheService class. +/// \see CacheService +class CacheServerRequest : public BaseRequest { + public: + friend class CacheServer; + enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; + explicit CacheServerRequest(int32_t queue_id) + : BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown), + qid_(queue_id), + st_(STATE::CREATE), + responder_(&ctx_) {} + + ~CacheServerRequest() = default; + + /// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this + /// functor will translate each protobuf into some form understood by by CacheService class. 
+ /// \param svc Async service + /// \param cq Completion queue + /// \return Status object + Status operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq); + + /// \brief Override the base class Print method + /// \param out + void Print(std::ostream &out) const override; + + /// \brief Getter of the queue id + /// \return The queue where the request should go + int32_t getQid() const { return qid_; } + + private: + int32_t qid_; + Status rc_; + STATE st_; + grpc::ServerContext ctx_; + grpc::ServerAsyncResponseWriter responder_; +}; + +/// \brief Implementation of CacheServerGreeter +/// \note It is an async server +/// \see cache_grpc.proto +class CacheServerGreeterImpl final { + friend class CacheServer; + + public: + explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb); + virtual ~CacheServerGreeterImpl(); + /// \brief Brings up the gRPC server + /// \return Status object + Status Run(); + /// \brief Entry function to handle cache server requests + Status HandleRequest(int32_t worker_id); + + /// \brief Return the shared memory pool. + /// \return The shared memory pool + CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); } + + void Shutdown(); + + Status IpcResourceCleanup(); + + private: + int32_t port_; + size_t shm_pool_sz_in_gb_; + std::string unix_socket_; + CacheServerGreeter::AsyncService svc_; + std::unique_ptr cq_; + std::unique_ptr server_; + std::unique_ptr shm_pool_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc new file mode 100644 index 0000000000..3de7b67110 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "minddata/dataset/engine/cache/cache_server.h" +#include +#include +#ifdef USE_GLOG +#include +#endif +#include + +namespace ds = mindspore::dataset; + +int main(int argc, char **argv) { + ds::Status rc; + ds::CacheServer::Builder builder; + + // This executable is not to be called directly, and should be invoked by the cache_admin executable.
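The numeric arguments below are parsed with bare strtol, which quietly maps malformed input to 0. A hedged sketch of a stricter variant (an illustrative helper, not part of the patch):

#include <cerrno>
#include <cstdio>
#include <cstdlib>

// Parse an entire string as a base-10 long; returns false on garbage or overflow.
bool ParseLong(const char *s, long *out) {
  char *end = nullptr;
  errno = 0;
  long v = std::strtol(s, &end, 10);
  if (errno == ERANGE || end == s || *end != '\0') {
    return false;
  }
  *out = v;
  return true;
}

int main(int argc, char **argv) {
  long port = 0;
  if (argc < 2 || !ParseLong(argv[1], &port)) {
    std::fprintf(stderr, "usage: %s <port>\n", argv[0]);
    return 1;
  }
  std::printf("port = %ld\n", port);
  return 0;
}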
+ if (argc != 7) { + rc = ds::Status(ds::StatusCode::kSyntaxError); + std::cerr << rc.ToString() << std::endl; + return static_cast(rc.get_code()); + } + + builder.SetRootDirectory(argv[1]) + .SetNumWorkers(strtol(argv[2], nullptr, 10)) + .SetPort(strtol(argv[3], nullptr, 10)) + .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10)); + +#ifdef USE_GLOG + FLAGS_minloglevel = strtol(argv[5], nullptr, 10); +#endif + + auto daemonize_string = argv[6]; + bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 || + strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0; + + // We always change directory to / on unix rather than using the directory where the cache_server + // is called. This is a standard procedure for daemonizing a process on unix. + if (chdir("/") == -1) { + std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno); + std::cerr << errMsg << std::endl; + return -1; + } + + // Simple check of the parameters before we move on. + rc = builder.SanityCheck(); + if (rc.IsError()) { + std::cerr << rc.ToString() << std::endl; + return static_cast(rc.get_code()); + } + +#ifdef USE_GLOG + FLAGS_log_dir = "/tmp"; + google::InitGoogleLogging(argv[0]); +#endif + + if (daemonize) { + // fork the child process to become the daemon + pid_t pid = fork(); + // failed to fork + if (pid < 0) { + std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno); + std::cerr << err_msg << std::endl; + return errno; + } else if (pid > 0) { + // Parent + std::cerr << "cache server daemon process has been created as process id: " << pid + << "\nCheck the log file for any startup errors" << std::endl; + signal(SIGCHLD, SIG_IGN); // ignore the SIGCHLD signal. + return 0; + } else { + // The child process continues from here when daemonizing; the parent has already exited. + // If we are running in the foreground, none of the code in the block below will be run. + pid_t sid; + umask(0); + sid = setsid(); + if (sid < 0) { + MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno); + return errno; + } + close(0); + close(1); + close(2); + } + } + + // Dump the summary + MS_LOG(INFO) << builder << std::endl; + rc = builder.Build(); + if (rc.IsOk()) { + ds::CacheServer &cs = ds::CacheServer::GetInstance(); + // Kick off the threads. Loop forever and never return unless there is an error. + rc = cs.Run(); + if (rc.get_code() == ds::StatusCode::kDuplicateKey) { + std::string errMsg = "Server is already started"; + MS_LOG(ERROR) << errMsg; + std::cerr << errMsg << std::endl; + return 0; + } + } + if (rc.IsError()) { + MS_LOG(ERROR) << rc.ToString(); + std::cerr << rc.ToString() << std::endl; + return static_cast(rc.get_code()); + } + return 0; +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc index a460e43aea..fc69a7eeab 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc @@ -14,154 +14,149 @@ * limitations under the License.
*/ #include "minddata/dataset/engine/cache/cache_request.h" - +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/cache/cache_fbb.h" namespace mindspore { namespace dataset { - -Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) { - buffers_.reserve(row.size() + 1); - RETURN_IF_NOT_OK(SerializeTensorRowHeader(row)); - buffers_.push_back(fbb_->GetBufferPointer()); - for (const auto &ts : row) { - buffers_.push_back(ts->GetBuffer()); - } +Status BaseRequest::Wait() { + RETURN_IF_NOT_OK(wp_.Wait()); + Status remote_rc(static_cast(reply_.rc()), reply_.msg()); + RETURN_IF_NOT_OK(remote_rc); + // Any extra work to do before we return back to the client. + RETURN_IF_NOT_OK(PostReply()); return Status::OK(); } - -Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) { - try { - fbb_ = std::make_shared(); - std::vector> v; - std::vector tensor_sz; - v.reserve(row.size()); - tensor_sz.reserve(row.size()); - // We will go through each column in the row. - for (const std::shared_ptr &ts_ptr : row) { - flatbuffers::Offset ts_off; - RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off)); - v.push_back(ts_off); - tensor_sz.push_back(ts_ptr->SizeInBytes()); +Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) { + CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row"); + CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch"); + // Calculate how many bytes (not counting the cookie) we are sending to the server. We only + // use shared memory (if supported) if we exceed certain amount + std::shared_ptr fbb; + RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb)); + sz_ += fbb->GetSize(); + for (const auto &ts : row) { + sz_ += ts->SizeInBytes(); + } + bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false; + uint32_t flag = 0; + if (support_local_bypass_) { + BitSet(&flag, kLocalClientSupport); + } + if (sent_using_local_bypass) { + BitSet(&flag, kDataIsInSharedMemory); + } + rq_.set_flag(flag); + if (sent_using_local_bypass) { + MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data"; + // Allocate shared memory from the server + auto mem_rq = std::make_shared(rq_.connection_id(), sz_); + RETURN_IF_NOT_OK(cc->PushRequest(mem_rq)); + RETURN_IF_NOT_OK(mem_rq->Wait()); + addr_ = mem_rq->GetAddr(); + // Now we need to add that to the base address of where we attach. + auto base = cc->SharedMemoryBaseAddr(); + auto p = reinterpret_cast(reinterpret_cast(base) + addr_); + // Now we copy the data onto shared memory. 
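Before the copy continues below, note that the bypass decision above is plain flag arithmetic against kLocalByPassThreshold. A minimal sketch of that decision with BitSet/BitTest-style helpers; the helper bodies here are plausible stand-ins for the real ones in minddata/dataset/util/bit.h:

#include <cstdint>
#include <cstdio>

constexpr int32_t kLocalByPassThreshold = 64 * 1024;
constexpr uint32_t kLocalClientSupport = 1;
constexpr uint32_t kDataIsInSharedMemory = 2;

// Plausible implementations; the real helpers live in util/bit.h.
inline void BitSet(uint32_t *flag, uint32_t bit) { *flag |= bit; }
inline bool BitTest(uint32_t flag, uint32_t bit) { return (flag & bit) == bit; }

int main() {
  int64_t sz = 100 * 1024;           // bytes about to be sent
  bool support_local_bypass = true;  // client and server share the machine
  uint32_t flag = 0;
  if (support_local_bypass) BitSet(&flag, kLocalClientSupport);
  if (support_local_bypass && sz >= kLocalByPassThreshold) BitSet(&flag, kDataIsInSharedMemory);
  std::printf("shared memory path: %s\n", BitTest(flag, kDataIsInSharedMemory) ? "yes" : "no");
  return 0;
}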
+ WritableSlice all(p, sz_); + auto offset = fbb->GetSize(); + ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize()); + Status copy_rc; + copy_rc = WritableSlice::Copy(&all, header); + if (copy_rc.IsOk()) { + for (const auto &ts : row) { + WritableSlice row_data(all, offset, ts->SizeInBytes()); + ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes()); + copy_rc = WritableSlice::Copy(&row_data, src); + if (copy_rc.IsError()) { + break; + } + offset += ts->SizeInBytes(); + } + // Fill in where to find the data + AddDataLocation(); } - auto column_off = fbb_->CreateVector(v); - auto data_sz_off = fbb_->CreateVector(tensor_sz); - TensorRowHeaderMsgBuilder row_builder(*fbb_); - row_builder.add_column(column_off); - row_builder.add_data_sz(data_sz_off); - // Pass the row_id even if it may not be known. - row_builder.add_row_id(row.getId()); - row_builder.add_size_of_this(-1); // fill in later after we call Finish. - auto out = row_builder.Finish(); - fbb_->Finish(out); - // Now go back to fill in size_of_this in the flat buffer. - auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer()); - auto success = msg->mutate_size_of_this(fbb_->GetSize()); - if (!success) { - RETURN_STATUS_UNEXPECTED("Unable to set size_of_this"); + if (copy_rc.IsError()) { + // We need to return the memory back to the server + auto mfree_req = GenerateFreeBlockRequest(); + Status rc = cc->PushRequest(mfree_req); + // But we won't wait for the result for the sake of performance. + if (rc.IsError()) { + MS_LOG(ERROR) << "Push request for free memory failed."; + } + return copy_rc; } - return Status::OK(); - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } else { + // We have already filled the first buffer which is the cookie. + sz_ += rq_.buf_data(0).size(); + rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize()); + for (const auto &ts : row) { + rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes()); + } + MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments"; } + return Status::OK(); } -Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, - flatbuffers::Offset *out_off) { - RETURN_UNEXPECTED_IF_NULL(out_off); - const Tensor *ts = ts_ptr.get(); - auto shape_off = fbb_->CreateVector(ts->shape().AsVector()); - const auto ptr = ts->GetBuffer(); - if (ptr == nullptr) { - RETURN_STATUS_UNEXPECTED("Tensor buffer is null"); - } - auto src = ts->type().value(); - TensorType dest; -#define CASE(t) \ - case DataType::t: \ - dest = TensorType::TensorType_##t; \ - break - // Map the type to fill in the flat buffer. - switch (src) { - CASE(DE_BOOL); - CASE(DE_INT8); - CASE(DE_UINT8); - CASE(DE_INT16); - CASE(DE_UINT16); - CASE(DE_INT32); - CASE(DE_UINT32); - CASE(DE_INT64); - CASE(DE_UINT64); - CASE(DE_FLOAT16); - CASE(DE_FLOAT32); - CASE(DE_FLOAT64); - CASE(DE_STRING); - default: - MS_LOG(ERROR) << "Unknown tensor. 
Dumping content:\n" << *ts; - RETURN_STATUS_UNEXPECTED("Unknown type"); +Status CacheRowRequest::PostReply() { + if (!reply_.result().empty()) { + row_id_from_server_ = strtoll(reply_.result().data(), nullptr, 10); } -#undef CASE - - TensorMetaMsgBuilder ts_builder(*fbb_); - ts_builder.add_dims(shape_off); - ts_builder.add_type(dest); - auto ts_off = ts_builder.Finish(); - *out_off = ts_off; return Status::OK(); } - -Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, - std::shared_ptr *out) { - RETURN_UNEXPECTED_IF_NULL(col_ts); - auto shape_in = col_ts->dims(); - auto type_in = col_ts->type(); - std::vector v; - v.reserve(shape_in->size()); - v.assign(shape_in->begin(), shape_in->end()); - TensorShape shape(v); - DataType::Type dest = DataType::DE_UNKNOWN; -#define CASE(t) \ - case TensorType_##t: \ - dest = DataType::Type::t; \ - break - - switch (type_in) { - CASE(DE_BOOL); - CASE(DE_INT8); - CASE(DE_UINT8); - CASE(DE_INT16); - CASE(DE_UINT16); - CASE(DE_INT32); - CASE(DE_UINT32); - CASE(DE_INT64); - CASE(DE_UINT64); - CASE(DE_FLOAT16); - CASE(DE_FLOAT32); - CASE(DE_FLOAT64); - CASE(DE_STRING); +Status CacheRowRequest::Prepare() { + if (BitTest(rq_.flag(), kDataIsInSharedMemory)) { + // First one is cookie, followed by address and then size. + CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == 3, "Incomplete rpc data"); + } else { + // First one is cookie. 2nd one is the google flat buffers followed by a number of buffers. + // But we are not going to decode them to verify. + CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= 3, "Incomplete rpc data"); } -#undef CASE - - DataType type(dest); - std::shared_ptr ts; - RETURN_IF_NOT_OK( - Tensor::CreateFromMemory(shape, type, static_cast(data.GetPointer()), data.GetSize(), &ts)); - // Next we restore the real data which can be embedded or stored separately. - if (ts->SizeInBytes() != data.GetSize()) { - MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" - << "Dumping tensor\n" - << *ts << "\n"; - RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); - } - *out = std::move(ts); return Status::OK(); } -Status BatchFetchRequest::RestoreRows(TensorTable *out) { +BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id, + bool local_bypass) + : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) { + rq_.set_connection_id(connection_id); + rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0); + // Convert the row id into a flatbuffer + flatbuffers::FlatBufferBuilder fbb; + auto off_t = fbb.CreateVector(row_id); + TensorRowIdsBuilder bld(fbb); + bld.add_row_id(off_t); + auto off = bld.Finish(); + fbb.Finish(off); + rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); +} + +Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) { RETURN_UNEXPECTED_IF_NULL(out); auto num_elements = row_id_.size(); - auto *offset_array = reinterpret_cast(mem_.GetPointer()); + const char *ptr = nullptr; + int64_t sz = 0; + // Tap into the reply flag to see where we can find the data. Server may decide the amount is + // so small that it doesn't use shared memory method. + auto flag = reply_.flag(); + bool dataOnSharedMemory = support_local_bypass_ ? 
(BitTest(flag, kDataIsInSharedMemory)) : false; + if (dataOnSharedMemory) { + auto addr = strtoll(reply_.result().data(), nullptr, 10); + ptr = reinterpret_cast(reinterpret_cast(baseAddr) + addr); + RETURN_UNEXPECTED_IF_NULL(out); + *out_addr = addr; + } else { + ptr = reply_.result().data(); + *out_addr = -1; + } + auto *offset_array = reinterpret_cast(ptr); + sz = offset_array[num_elements]; + CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch"); TensorTable tbl; tbl.reserve(num_elements); - ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes()); + ReadableSlice all(ptr, sz); for (auto i = 0; i < num_elements; ++i) { auto len = offset_array[i + 1] - offset_array[i]; TensorRow row; @@ -178,10 +173,12 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) { auto col_ts = msg->column()->Get(k); std::shared_ptr ts; ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k)); - RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts)); + RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts)); row.push_back(ts); ts_offset += data.GetSize(); } + } else { + CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected."); } tbl.push_back(std::move(row)); } @@ -189,36 +186,69 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) { return Status::OK(); } +CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, + CreateCacheRequest::CreateCacheFlag flag) + : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) { + // Type has been set already in the base constructor. So we need to fill in the connection info. + // On successful return, we will get the connection id + rq_.mutable_connection_info()->operator=(cinfo); +} + +Status CreateCacheRequest::Prepare() { + try { + flatbuffers::FlatBufferBuilder fbb; + CreateCacheRequestMsgBuilder bld(fbb); + bld.add_cache_mem_sz(cache_mem_sz_); + bld.add_flag(static_cast(flag_)); + auto off = bld.Finish(); + fbb.Finish(off); + rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); + return Status::OK(); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } +} + Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map &map) { try { - fbb_ = std::make_shared(); + flatbuffers::FlatBufferBuilder fbb; std::vector> v; v.reserve(map.size()); for (auto &column : map) { - auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second); + auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second); v.push_back(c); } - auto v_off = fbb_->CreateVector(v); - auto final_off = CreateSchemaMsg(*fbb_, v_off); - fbb_->Finish(final_off); - buf_ = fbb_->GetBufferPointer(); - len_of_buf_ = fbb_->GetSize(); + auto v_off = fbb.CreateVector(v); + auto final_off = CreateSchemaMsg(fbb, v_off); + fbb.Finish(final_off); + rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); return Status::OK(); } catch (const std::bad_alloc &e) { return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); } } -std::unordered_map FetchSchemaRequest::GetColumnMap() { - if (column_name_id_map_.empty()) { - auto *map_msg = flatbuffers::GetRoot(mem_.GetPointer()); - auto v = map_msg->column(); - for (auto i = 0; i < v->size(); ++i) { - auto col = map_msg->column()->Get(i); - column_name_id_map_.emplace(col->name()->str(), col->id()); - } +Status FetchSchemaRequest::PostReply() { + auto *map_msg = 
flatbuffers::GetRoot(reply_.result().data()); + auto v = map_msg->column(); + for (auto i = 0; i < v->size(); ++i) { + auto col = map_msg->column()->Get(i); + column_name_id_map_.emplace(col->name()->str(), col->id()); } - return column_name_id_map_; + return Status::OK(); +} + +std::unordered_map FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; } + +Status GetStatRequest::PostReply() { + auto *msg = flatbuffers::GetRoot(reply_.result().data()); + stat_.num_disk_cached = msg->num_disk_cached(); + stat_.num_mem_cached = msg->num_mem_cached(); + stat_.avg_cache_sz = msg->avg_cache_sz(); + stat_.max_row_id = msg->max_row_id(); + stat_.min_row_id = msg->min_row_id(); + stat_.cache_service_state = msg->state(); + return Status::OK(); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h index 6851cebe0c..4a2edcd136 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h @@ -18,11 +18,16 @@ #include #include +#include #include #include #include #include +#ifdef ENABLE_CACHE +#include "proto/cache_grpc.grpc.pb.h" +#endif +#include "proto/cache_grpc.pb.h" #include "minddata/dataset/core/tensor_row.h" #include "minddata/dataset/engine/cache/de_tensor_generated.h" #include "minddata/dataset/util/slice.h" @@ -30,6 +35,17 @@ namespace mindspore { namespace dataset { +class CacheClient; +/// \brief Statistic structure for GetStat request +struct CacheServiceStat { + int64_t num_mem_cached; + int64_t num_disk_cached; + int64_t avg_cache_sz; + row_id_type min_row_id; + row_id_type max_row_id; + int8_t cache_service_state; +}; + /// \brief CacheClient communicates with CacheServer using Requests. class BaseRequest { public: @@ -44,195 +60,301 @@ class BaseRequest { kCacheSchema = 6, kFetchSchema = 7, kBuildPhaseDone = 8, + kDropSession = 9, + kGenerateSessionId = 10, + kAllocateSharedBlock = 11, + kFreeSharedBlock = 12, + kStopService = 13, // Add new request before it. kRequestUnknown = 32767 }; - // For kCreateCache - enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; + friend class CacheServer; + friend class CacheServerRequest; + friend class CacheClientGreeter; + friend class CacheClientRequestTag; + /// \brief Base class of a cache server request - /// \param connection_id A combination of session id and crc that uniquely identifies a connection. 
/// \param type Type of the request - explicit BaseRequest(connection_id_type connection_id, RequestType type) - : type_(type), connection_id_(connection_id) {} + explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast(type_)); } virtual ~BaseRequest() = default; - /// \brief Wait for the completion of a request - /// \return Status returned from the cache server - Status Wait() { - RETURN_IF_NOT_OK(wp_.Wait()); - return rc_; + + /// \brief A print method for debugging + /// \param out The output stream to write output to + virtual void Print(std::ostream &out) const { out << "Request type: " << static_cast(type_); } + + /// \brief << Stream output operator overload + /// \param out reference to the output stream + /// \param rq reference to the BaseRequest + /// \return the output stream + friend std::ostream &operator<<(std::ostream &out, const BaseRequest &rq) { + rq.Print(out); + return out; } - /// \brief Getter function of the current connection id - /// \return Connection id - connection_id_type GetServerConnectionId() const { return connection_id_; } + /// \brief Derived class can implement extra work to be done before the request is sent to the server + virtual Status Prepare() { return Status::OK(); } + + /// \brief Derived class can implement extra work to be done after the reply comes back from the server + virtual Status PostReply() { return Status::OK(); } + + /// \brief A method for the client to wait for the availability of the result back from the server. + /// \return Status object + Status Wait(); + + protected: + CacheRequest rq_; // This is what we send to the server + CacheReply reply_; // This is what the server sends back private: RequestType type_; - connection_id_type connection_id_; - Status rc_; - WaitPost wp_; + WaitPost wp_; // A sync area used by the client side. }; + +class FreeSharedBlockRequest : public BaseRequest { + public: + friend class CacheServer; + explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr) + : BaseRequest(RequestType::kFreeSharedBlock) { + rq_.set_connection_id(connection_id); + rq_.add_buf_data(std::to_string(addr)); + } + ~FreeSharedBlockRequest() = default; +}; + /// \brief Request to cache a single TensorRow class CacheRowRequest : public BaseRequest { public: friend class CacheServer; - explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie) - : BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {} + friend class CacheClient; + explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass) + : BaseRequest(RequestType::kCacheRow), + support_local_bypass_(local_bypass), + addr_(-1), + sz_(0), + row_id_from_server_(-1) { + rq_.set_connection_id(connection_id); + rq_.add_buf_data(cookie); + } ~CacheRowRequest() = default; /// \brief Serialize a TensorRow for streaming to the cache server /// \param row TensorRow /// \return Status object - Status SerializeCacheRowRequest(const TensorRow &row); + Status SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row); + + /// \brief Sanity check before we send the row.
+ /// \return Status object + Status Prepare() override; + + /// \brief Override the base function get the row id returned from the server + /// \return Status object + Status PostReply() override; + /// \brief Return the row id assigned to this row for non-mappable dataset /// \return row id of the cached row row_id_type GetRowIdAfterCache() { return row_id_from_server_; } + /// \brief If we are doing local bypass, fill in extra request information of where the data is located. + void AddDataLocation() { + if (support_local_bypass_) { + rq_.add_buf_data(std::to_string(addr_)); + rq_.add_buf_data(std::to_string(sz_)); + } + } + + /// \brief If we fail to send the data to the server using shared memory method, we should release + /// the shared memory by sending another request. The following function will generate a suitable + /// request for the CacheClient to send. + std::shared_ptr GenerateFreeBlockRequest() { + return std::make_shared(rq_.connection_id(), addr_); + } + private: - std::shared_ptr fbb_; + bool support_local_bypass_; + int64_t addr_; + int64_t sz_; row_id_type row_id_from_server_; - std::vector buffers_; - std::string cookie_; - - /// \brief Private function to serialize one TensorRow - /// \param row TensorRow - /// \return Status object - Status SerializeTensorRowHeader(const TensorRow &row); - /// \brief Private function to serialize one Tensor - /// \param ts_ptr Tensor - /// \return Status object - Status SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, flatbuffers::Offset *out_off); }; + /// \brief Request to fetch rows in batch class BatchFetchRequest : public BaseRequest { public: friend class CacheServer; friend class CacheService; - BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id) - : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {} + BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id, bool local_bypass); ~BatchFetchRequest() = default; - Status RestoreRows(TensorTable *out); + Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); private: + bool support_local_bypass_; std::vector row_id_; - MemGuard mem_; - Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr *out); }; + /// \brief Request to create a cache for the current connection -class CreationCacheRequest : public BaseRequest { +class CreateCacheRequest : public BaseRequest { public: friend class CacheServer; + enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; + /// \brief Constructor /// \param connection_id /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited /// \param flag Attributes of the cache. 
- explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz, - CreateCacheFlag flag = CreateCacheFlag::kNone) - : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {} - - ~CreationCacheRequest() = default; + explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, + CreateCacheFlag flag = CreateCacheFlag::kNone); + ~CreateCacheRequest() = default; + void ParseResult(connection_id_type *id, std::string *out) { + auto p = flatbuffers::GetRoot(reply_.result().data()); + *id = p->connection_id(); + *out = p->cookie()->str(); + } - std::string cookie() const { return cookie_; } + /// Overload the base class Prepare + Status Prepare() override; private: - uint64_t cache_mem_sz; + uint64_t cache_mem_sz_; CreateCacheFlag flag_; - std::string cookie_; }; + /// \brief Request to purge a cache. class PurgeCacheRequest : public BaseRequest { public: friend class CacheServer; - explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {} - + explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) { + rq_.set_connection_id(connection_id); + } ~PurgeCacheRequest() = default; }; + /// \brief Request to destroy a cache class DestroyCacheRequest : public BaseRequest { public: friend class CacheServer; - explicit DestroyCacheRequest(connection_id_type connection_id) - : BaseRequest(connection_id, RequestType::kDestroyCache) {} - - /// \brief Destructor + explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) { + rq_.set_connection_id(connection_id); + } ~DestroyCacheRequest() = default; }; + /// \brief Obtain the statistics of the current connection class GetStatRequest : public BaseRequest { public: friend class CacheServer; friend class CacheService; - explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {} + explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetStat) { + rq_.set_connection_id(connection_id); + } ~GetStatRequest() = default; - row_id_type GetMinRowId() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->min_row_id(); - } - row_id_type GetMaxRowId() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->max_row_id(); - } - int64_t GetNumMemCached() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->num_mem_cached(); - } - int64_t GetNumDiskCached() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->num_disk_cached(); - } - uint8_t GetState() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->state(); + /// \brief Override base function to process the result. 
+ Status PostReply() override; + + void GetStat(CacheServiceStat *stat) { + if (stat != nullptr) { + (*stat) = stat_; + } } private: - MemGuard mem_; + CacheServiceStat stat_{}; }; + /// \brief Request to cache a schema class CacheSchemaRequest : public BaseRequest { public: friend class CacheServer; - explicit CacheSchemaRequest(connection_id_type connection_id) - : BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {} + explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) { + rq_.set_connection_id(connection_id); + } ~CacheSchemaRequest() = default; Status SerializeCacheSchemaRequest(const std::unordered_map &map); - const void *GetBuffer() const { return buf_; } - - private: - std::shared_ptr fbb_; - const void *buf_; - int64_t len_of_buf_; }; + /// \brief Request to fetch a schema class FetchSchemaRequest : public BaseRequest { public: friend class CacheServer; - explicit FetchSchemaRequest(connection_id_type connection_id) - : BaseRequest(connection_id, RequestType::kFetchSchema) {} + explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) { + rq_.set_connection_id(connection_id); + } ~FetchSchemaRequest() = default; + Status PostReply() override; + std::unordered_map GetColumnMap(); private: - MemGuard mem_; std::unordered_map column_name_id_map_; }; + /// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. class BuildPhaseDoneRequest : public BaseRequest { public: friend class CacheServer; BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) - : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {} - + : BaseRequest(RequestType::kBuildPhaseDone), cookie_(cookie) { + rq_.set_connection_id(connection_id); + rq_.add_buf_data(cookie_); + } ~BuildPhaseDoneRequest() = default; private: std::string cookie_; }; + +/// \brief Request to drop all the caches in the current session +class DropSessionRequest : public BaseRequest { + public: + friend class CacheServer; + explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) { + rq_.mutable_connection_info()->operator=(cinfo); + } + ~DropSessionRequest() = default; +}; + +class GenerateSessionIdRequest : public BaseRequest { + public: + friend class CacheServer; + GenerateSessionIdRequest() : BaseRequest(RequestType::kGenerateSessionId) { + // We don't have anything client info nor connection id to send. But we will manually + // set the connection id to 0. + rq_.set_connection_id(0); + } + + ~GenerateSessionIdRequest() = default; + + session_id_type GetSessionId() { return atoi(reply_.result().data()); } +}; + +class AllocateSharedBlockRequest : public BaseRequest { + public: + friend class CacheServer; + explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz) + : BaseRequest(RequestType::kAllocateSharedBlock) { + rq_.set_connection_id(connection_id); + rq_.add_buf_data(std::to_string(requestedSz)); + } + ~AllocateSharedBlockRequest() = default; + + /// \brief On return from the server, we get the (relative) address where + /// the free block is located. 
+  /// \return
+  int64_t GetAddr() {
+    auto addr = strtoll(reply_.result().data(), nullptr, 10);
+    return addr;
+  }
+};
+
+class ShutdownRequest : public BaseRequest {
+ public:
+  friend class CacheServer;
+  ShutdownRequest() : BaseRequest(RequestType::kStopService) {}
+  ~ShutdownRequest() = default;
+};
 }  // namespace dataset
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
index c9fb6ecab1..c181376c76 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
@@ -14,25 +14,89 @@
  * limitations under the License.
  */
 #include "minddata/dataset/engine/cache/cache_server.h"
+#include
+#include
+#include
+#include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/engine/cache/cache_service.h"
 #include "minddata/dataset/engine/cache/cache_request.h"
 #include "minddata/dataset/util/bit.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/random.h"
+#ifdef CACHE_LOCAL_CLIENT
+#include "minddata/dataset/util/sig_handler.h"
+#endif

 namespace mindspore {
 namespace dataset {
+CacheServer *CacheServer::instance_ = nullptr;
+std::once_flag CacheServer::init_instance_flag_;
 Status CacheServer::DoServiceStart() {
+#ifdef CACHE_LOCAL_CLIENT
+  // We need to destroy the shared memory if the user hits Control-C
+  RegisterHandlers();
+#endif
   if (!top_.empty()) {
     Path spill(top_);
     RETURN_IF_NOT_OK(spill.CreateDirectories());
     MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
   }
   RETURN_IF_NOT_OK(vg_.ServiceStart());
-  cache_q_ = std::make_shared>(1024);
+  // There will be num_workers_ threads working on the grpc queue and
+  // the same number of threads working on the CacheServerRequest queue.
+  // Like a connector object we will set up the same number of queues but
+  // we do not need to preserve any order. We will set the capacity of
+  // each queue to be 128 since we are just pushing memory pointers which
+  // are only 8 bytes each.
+  const int32_t que_capacity = 128;
+  // This is the request queue from the client
+  cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>();
+  cache_q_->Init(num_workers_, que_capacity);
+  // For the grpc completion queue to work, we need to allocate some
+  // tags, which in our case are instances of CacheServerRequest.
+  // They get recycled, so we allocate them in advance and push
+  // them into a free list. We need more than (two or three times) the
+  // size of the cache_q: while each worker is working on a CacheServerRequest,
+  // we need some extra tags circulating in the grpc completion queue.
+  const int32_t multiplier = 3;
+  const int32_t free_list_capacity = multiplier * (que_capacity + 1);
+  free_list_ = std::make_shared<QueueList<CacheServerRequest *>>();
+  free_list_->Init(num_workers_, free_list_capacity);
+  // We need to have a reference to the services memory pool in case
+  // the Services singleton goes out of scope earlier than us.
+  mp_ = Services::GetInstance().GetServiceMemPool();
+  Allocator<CacheServerRequest> alloc(mp_);
+  tag_.reserve(num_workers_);
+  // Now we populate all the free lists.
+  for (auto m = 0; m < num_workers_; ++m) {
+    // Ideally we would allocate all the free lists in one malloc, but it turns out that
+    // exceeds the Arena size. So we will allocate one segment at a time.
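+    // For illustration, with the default num_workers_ = 32 and que_capacity = 128,
+    // free_list_capacity is 3 * (128 + 1) = 387 tags per worker, i.e. 32 * 387 = 12384
+    // pre-allocated CacheServerRequest slots in total.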
+    auto my_tag = std::make_unique<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>(alloc);
+    // Allocate the tag and assign it the current queue
+    RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m));
+    for (int i = 0; i < free_list_capacity; ++i) {
+      RETURN_IF_NOT_OK(free_list_->operator[](m)->Add((*my_tag)[i]));
+    }
+    tag_.push_back(std::move(my_tag));
+  }
   RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
-  auto f = std::bind(&CacheServer::ServerRequest, this);
-  // Spawn a a few threads to serve the request.
+  RETURN_IF_NOT_OK(free_list_->Register(&vg_));
+  // Spawn a few threads to serve the real request.
+  auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
+  for (auto i = 0; i < num_workers_; ++i) {
+    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i)));
+  }
+  // Start the comm layer
+  try {
+    comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
+    RETURN_IF_NOT_OK(comm_layer_->Run());
+  } catch (const std::exception &e) {
+    RETURN_STATUS_UNEXPECTED(e.what());
+  }
+  // Finally loop forever to handle the requests.
+  auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1);
   for (auto i = 0; i < num_workers_; ++i) {
-    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f));
+    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i)));
   }
   return Status::OK();
 }
@@ -65,188 +129,551 @@ CacheService *CacheServer::GetService(connection_id_type id) const {
   return nullptr;
 }

-Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz,
-                                  BaseRequest::CreateCacheFlag flag, std::string *out_cookie) {
+Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
+  CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info");
+  std::string cookie;
+  auto session_id = rq->connection_info().session_id();
+  auto crc = rq->connection_info().crc();
+  // We concat both numbers to form the internal connection id.
+  auto connection_id = GetConnectionID(session_id, crc);
+  CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache");
+  auto &create_cache_buf = rq->buf_data(0);
+  auto p = flatbuffers::GetRoot<CreateCacheRequestMsg>(create_cache_buf.data());
+  auto flag = static_cast<CreateCacheRequest::CreateCacheFlag>(p->flag());
+  auto cache_mem_sz = p->cache_mem_sz();
   // We can't do spilling unless this server is setup with a spill path in the first place
-  bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk;
+  bool spill =
+    (flag & CreateCacheRequest::CreateCacheFlag::kSpillToDisk) == CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
   bool generate_id =
-    (flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId;
+    (flag & CreateCacheRequest::CreateCacheFlag::kGenerateRowId) == CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
   if (spill && top_.empty()) {
     RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
   }
-  RETURN_UNEXPECTED_IF_NULL(out_cookie);
-  *out_cookie = "";
+  flatbuffers::FlatBufferBuilder fbb;
+  flatbuffers::Offset<flatbuffers::String> off_cookie;
   // Before creating the cache, first check if this is a request for a shared usage of an existing cache
   // If two CreateService come in with identical connection_id, we need to serialize the create.
   // The first create will be successful and be given a special cookie.
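  // For example, two clients that use the same session id and read the same dataset
  // (and therefore compute the same crc) resolve to one connection_id and will share
  // a single cache; only the first creator receives a non-empty cookie.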
   UniqueLock lck(&rwLock_);
+  // Early exit if we are doing global shutdown
+  if (global_shutdown_) {
+    return Status::OK();
+  }
   auto end = all_caches_.end();
   auto it = all_caches_.find(connection_id);
+  bool duplicate = false;
   if (it == end) {
     std::unique_ptr<CacheService> cs;
     try {
       cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
       RETURN_IF_NOT_OK(cs->ServiceStart());
-      *out_cookie = cs->cookie();
+      cookie = cs->cookie();
       all_caches_.emplace(connection_id, std::move(cs));
     } catch (const std::bad_alloc &e) {
       return Status(StatusCode::kOutOfMemory);
     }
   } else {
+    duplicate = true;
     MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
-    // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
-    // treat it as OK.
-    return Status(StatusCode::kDuplicateKey);
   }
+  off_cookie = fbb.CreateString(cookie);
+  CreateCacheReplyMsgBuilder bld(fbb);
+  bld.add_connection_id(connection_id);
+  bld.add_cookie(off_cookie);
+  auto off = bld.Finish();
+  fbb.Finish(off);
+  reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
+  // Track the history of all the sessions that we have created so far.
+  history_sessions_.insert(session_id);
+  // We can return OK, but we will return a duplicate key so the user can act accordingly,
+  // either ignoring it or treating it as OK.
+  return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK();
+}
+
+Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) {
+  // We need a strong lock to protect the map.
+  UniqueLock lck(&rwLock_);
+  // If cs is null it has already been destroyed and we ignore the request.
+  if (cs != nullptr) {
+    auto id = rq->connection_id();
+    MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
+    // std::map will invoke the destructor of CacheService. So we don't need to do anything here.
+    auto n = all_caches_.erase(id);
+    if (n == 0) {
+      // It has been destroyed by another duplicate request.
+      MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to destroy cache service";
+    }
+  }
+  return Status::OK();
+}
+
+inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
+  auto connection_id = rq->connection_id();
+  if (cs == nullptr) {
+    std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+  } else {
+    auto sz = rq->buf_data_size();
+    std::vector<const void *> buffers;
+    buffers.reserve(sz);
+    // First piece of data is the cookie and is required
+    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
+    auto &cookie = rq->buf_data(0);
+    // Only if the cookie matches can we accept inserts into a cache that has a build phase
+    if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
+      // Push the address of each buffer (in the form of std::string coming in from protobuf) into
+      // a vector of buffers
+      for (auto i = 1; i < sz; ++i) {
+        buffers.push_back(rq->buf_data(i).data());
+      }
+      row_id_type id = -1;
+      RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
+      reply->set_result(std::to_string(id));
+    } else {
+      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
+    }
+  }
   return Status::OK();
 }

-/// This is the main loop the cache server thread(s) are running.
-/// Each thread will pop a request and save the result in the same request.
-/// The sender will wait on the wait post in the request.
-/// Once the request is fulfilled, the server thread will do a post signalling the request is
-/// is processed.
+Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
+  auto connection_id = rq->connection_id();
+  auto shared_pool = comm_layer_->GetSharedMemoryPool();
+  auto *base = shared_pool->SharedMemoryBaseAddr();
+  // Ensure we got 3 pieces of data coming in
+  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data");
+  // First piece of data is the cookie and is required
+  auto &cookie = rq->buf_data(0);
+  // Second piece of data is the address where we can find the serialized data
+  auto addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
+  auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
+  // Third piece of data is the size of the serialized data that we need to transfer
+  auto sz = strtoll(rq->buf_data(2).data(), nullptr, 10);
+  // Successful or not, we need to free the memory on exit.
+  Status rc;
+  if (cs == nullptr) {
+    std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
+    rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+  } else {
+    // Only if the cookie matches can we accept inserts into a cache that has a build phase
+    if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
+      row_id_type id = -1;
+      ReadableSlice src(p, sz);
+      rc = cs->FastCacheRow(src, &id);
+      reply->set_result(std::to_string(id));
+    } else {
+      rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
+    }
+  }
+  // Return the block to the shared memory.
+  shared_pool->Deallocate(p);
+  return rc;
+}
+
+Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
+  auto connection_id = rq->connection_id();
+  if (cs == nullptr) {
+    std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+  } else {
+    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id");
+    auto &row_id_buf = rq->buf_data(0);
+    auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
+    std::vector<row_id_type> row_id;
+    auto sz = p->row_id()->size();
+    row_id.reserve(sz);
+    for (auto i = 0; i < sz; ++i) {
+      row_id.push_back(p->row_id()->Get(i));
+    }
+    int64_t mem_sz = 0;
+    std::vector<key_size_pair> v;
+    RETURN_IF_NOT_OK(cs->PreBatchFetch(row_id, &v, &mem_sz));
+    auto client_flag = rq->flag();
+    bool local_client = BitTest(client_flag, kLocalClientSupport);
+    // For a large amount of data to be sent back, we will use shared memory provided it is a local
+    // client that has local bypass support
+    bool local_bypass = local_client ? (mem_sz >= kLocalByPassThreshold) : false;
+    reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
+    if (local_bypass) {
+      // We will use shared memory
+      auto shared_pool = comm_layer_->GetSharedMemoryPool();
+      auto *base = shared_pool->SharedMemoryBaseAddr();
+      void *q = nullptr;
+      RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
+      WritableSlice dest(q, mem_sz);
+      RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest));
+      // We can't return the absolute address, which makes no sense to the client.
+      // Instead we return the difference.
+      auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base);
+      reply->set_result(std::to_string(difference));
+    } else {
+      // We are going to use std::string to allocate and hold the result, which will eventually
+      // be 'moved' to the protobuf message (which underneath is also a std::string) in order
+      // to minimize memory copies.
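+      // The buffer produced by BatchFetch is an int64 offset array followed by the row
+      // data: offsets[0..n] bracket the n rows, so row i occupies [offsets[i], offsets[i+1]).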
+      std::string mem;
+      try {
+        mem.resize(mem_sz);
+        CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error");
+      } catch (const std::bad_alloc &e) {
+        return Status(StatusCode::kOutOfMemory);
+      }
+      WritableSlice dest(mem.data(), mem_sz);
+      RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest));
+      reply->set_result(std::move(mem));
+    }
+  }
+  return Status::OK();
+}
+
+inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
+  auto connection_id = rq->connection_id();
+  if (cs == nullptr) {
+    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+  } else {
+    CacheService::ServiceStat svc_stat;
+    RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
+    flatbuffers::FlatBufferBuilder fbb;
+    ServiceStatMsgBuilder bld(fbb);
+    bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
+    bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
+    bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz);
+    bld.add_max_row_id(svc_stat.max_);
+    bld.add_min_row_id(svc_stat.min_);
+    bld.add_state(svc_stat.state_);
+    auto offset = bld.Finish();
+    fbb.Finish(offset);
+    reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
+  }
+  return Status::OK();
+}
+
+inline Status CacheSchema(CacheService *cs, CacheRequest *rq) {
+  auto connection_id = rq->connection_id();
+  if (cs == nullptr) {
+    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+  } else {
+    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information");
+    auto &create_schema_buf = rq->buf_data(0);
+    RETURN_IF_NOT_OK(cs->CacheSchema(create_schema_buf.data(), create_schema_buf.size()));
+  }
+  return Status::OK();
+}
+
+inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
+  auto connection_id = rq->connection_id();
+  if (cs == nullptr) {
+    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+  } else {
+    // We are going to use std::string to allocate and hold the result, which will eventually
+    // be 'moved' to the protobuf message (which underneath is also a std::string) in order
+    // to minimize memory copies.
+    std::string mem;
+    RETURN_IF_NOT_OK(cs->FetchSchema(&mem));
+    reply->set_result(std::move(mem));
+  }
+  return Status::OK();
+}
+
+inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) {
+  auto connection_id = rq->connection_id();
+  if (cs == nullptr) {
+    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+  } else {
+    // First piece of data is the cookie
+    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
+    auto &cookie = rq->buf_data(0);
+    // We can only allow the phase to switch if the cookie matches.
+    if (cookie == cs->cookie()) {
+      RETURN_IF_NOT_OK(cs->BuildPhaseDone());
+    } else {
+      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
+    }
+  }
+  return Status::OK();
+}
+
+Status CacheServer::PurgeCache(CacheService *cs) {
+  SharedLock lck(&rwLock_);
+  // If shutdown is in progress, ignore the command.
+  if (global_shutdown_) {
+    return Status::OK();
+  }
+  // If it is already purged (cs is null), ignore it.
+  if (cs != nullptr) {
+    RETURN_IF_NOT_OK(cs->Purge());
+  }
+  return Status::OK();
+}
+
+inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *reply) {
+  reply->set_result(std::to_string(session_id));
+  return Status::OK();
+}
+
+/// \brief This is the main loop that the cache server thread(s) run.
+/// Each thread pops a request and sends the result back to the client using grpc.
 /// \return
-Status CacheServer::ServerRequest() {
+Status CacheServer::ServerRequest(int32_t worker_id) {
   TaskManager::FindMe()->Post();
-  // Loop forever until we are interrupted.
-  while (true) {
-    BaseRequest *base_rq = nullptr;
-    RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq));
-    auto cs = GetService(base_rq->connection_id_);
+  auto &my_que = cache_q_->operator[](worker_id);
+  // Loop forever until we are interrupted or shut down.
+  while (!global_shutdown_) {
+    CacheServerRequest *cache_req = nullptr;
+    RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
+    auto &rq = cache_req->rq_;
+    auto &reply = cache_req->reply_;
+    CacheService *cs = nullptr;
+    // Requests come in roughly two kinds. One kind works at the cache level with a
+    // connection id. The other kind works at a higher level without a connection id.
+    if (!rq.has_connection_info()) {
+      cs = GetService(rq.connection_id());
+    }
     // Except for creating a new session, we expect cs is not null.
-    switch (base_rq->type_) {
+    switch (cache_req->type_) {
       case BaseRequest::RequestType::kCacheRow: {
-        if (cs == nullptr) {
-          std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
-          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+        // Look into the flag to see where we can find the data and
+        // call the appropriate method.
+        auto flag = rq.flag();
+        if (BitTest(flag, kDataIsInSharedMemory)) {
+          cache_req->rc_ = FastCacheRow(cs, &rq, &reply);
         } else {
-          auto *rq = reinterpret_cast(base_rq);
-          // Only if the cookie matches, we can accept insert into this cache that has a build phase
-          if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) {
-            rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_);
-          } else {
-            return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
-          }
+          cache_req->rc_ = CacheRow(cs, &rq, &reply);
         }
         break;
       }
       case BaseRequest::RequestType::kBatchFetchRows: {
-        if (cs == nullptr) {
-          std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
-          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
-        } else {
-          auto *rq = reinterpret_cast(base_rq);
-          rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_);
-        }
+        cache_req->rc_ = BatchFetchRows(cs, &rq, &reply);
         break;
       }
       case BaseRequest::RequestType::kCreateCache: {
-        // If the cache is already created we still need to run the creation so that we do sanity checks on the
-        // client id and return the cache id back to the user.
-        auto *rq = reinterpret_cast(base_rq);
-        rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_);
+        cache_req->rc_ = CreateService(&rq, &reply);
         break;
       }
       case BaseRequest::RequestType::kPurgeCache: {
-        if (cs != nullptr) {
-          base_rq->rc_ = cs->Purge();
-        } else {
-          // it is already purged. Ignore it.
-          base_rq->rc_ = Status::OK();
-        }
+        cache_req->rc_ = PurgeCache(cs);
         break;
       }
       case BaseRequest::RequestType::kDestroyCache: {
-        if (cs != nullptr) {
-          // We need a strong lock to protect the map.
- connection_id_type id = base_rq->connection_id_; - UniqueLock lck(&rwLock_); - // std::map will invoke the constructor of CacheService. So we don't need to do anything here. - auto n = all_caches_.erase(id); - if (n == 0) { - // It has been destroyed by another duplicate request. - MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; - } - base_rq->rc_ = Status::OK(); - } else { - // it is already destroyed. Ignore it. - base_rq->rc_ = Status::OK(); - } + cache_req->rc_ = DestroyCache(cs, &rq); break; } case BaseRequest::RequestType::kGetStat: { - if (cs == nullptr) { - std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - CacheService::ServiceStat svc_stat; - rq->rc_ = cs->GetStat(&svc_stat); - if (rq->rc_.IsOk()) { - flatbuffers::FlatBufferBuilder fbb; - ServiceStatMsgBuilder bld(fbb); - bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); - bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); - bld.add_max_row_id(svc_stat.max_); - bld.add_min_row_id(svc_stat.min_); - bld.add_state(svc_stat.state_); - auto offset = bld.Finish(); - fbb.Finish(offset); - rq->rc_ = rq->mem_.allocate(fbb.GetSize()); - if (rq->rc_.IsOk()) { - WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize()); - ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize()); - RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src)); - } - } - } + cache_req->rc_ = GetStat(cs, &rq, &reply); break; } case BaseRequest::RequestType::kCacheSchema: { - if (cs == nullptr) { - std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_); - } + cache_req->rc_ = CacheSchema(cs, &rq); break; } case BaseRequest::RequestType::kFetchSchema: { - if (cs == nullptr) { - std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - rq->rc_ = cs->FetchSchema(&rq->mem_); - } + cache_req->rc_ = FetchSchema(cs, &rq, &reply); break; } case BaseRequest::RequestType::kBuildPhaseDone: { - if (cs == nullptr) { - std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - // We can only allow to switch phase is the cookie match. 
-          if (rq->cookie_ == cs->cookie()) {
-            rq->rc_ = cs->BuildPhaseDone();
-          } else {
-            return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
-          }
-        }
+        cache_req->rc_ = BuildPhaseDone(cs, &rq);
         break;
       }
+      case BaseRequest::RequestType::kDropSession: {
+        cache_req->rc_ = DestroySession(&rq);
+        break;
+      }
+      case BaseRequest::RequestType::kGenerateSessionId: {
+        cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
+        break;
+      }
+      case BaseRequest::RequestType::kAllocateSharedBlock: {
+        cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
+        break;
+      }
+      case BaseRequest::RequestType::kFreeSharedBlock: {
+        cache_req->rc_ = FreeSharedMemory(&rq);
+        break;
+      }
+      case BaseRequest::RequestType::kStopService: {
+        // This command shuts down everything.
+        cache_req->rc_ = GlobalShutdown();
+        break;
+      }
       default:
-        base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type");
+        std::string errMsg("Unknown request type : ");
+        errMsg += std::to_string(static_cast<int16_t>(cache_req->type_));
+        cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
     }
     // Notify it is done, and move on to the next request.
-    base_rq->wp_.Set();
+    Status2CacheReply(cache_req->rc_, &reply);
+    cache_req->st_ = CacheServerRequest::STATE::FINISH;
+    // We will re-tag the request back to the grpc queue. Once it comes back from the client,
+    // the CacheServerRequest, i.e. the pointer cache_req, will be freed.
+    if (!global_shutdown_) {
+      cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
+    }
+  }
+  return Status::OK();
+}
+
+connection_id_type CacheServer::GetConnectionID(session_id_type session_id, uint32_t crc) const {
+  // The upper 32 bits hold the session id and the lower 32 bits hold the crc of the
+  // dataset graph below the cache.
+  connection_id_type connection_id =
+    (static_cast<connection_id_type>(session_id) << 32u) | static_cast<connection_id_type>(crc);
+  return connection_id;
+}
+
+session_id_type CacheServer::GetSessionID(connection_id_type connection_id) const {
+  return static_cast<session_id_type>(connection_id >> 32u);
+}
+
+CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
+                         int32_t shared_memory_sz_in_gb)
+    : top_(spill_path),
+      num_workers_(num_workers),
+      port_(port),
+      shared_memory_sz_in_gb_(shared_memory_sz_in_gb),
+      global_shutdown_(false) {}
+
+Status CacheServer::Run() {
+  RETURN_IF_NOT_OK(ServiceStart());
+  // This is called by the main function and we shouldn't exit; otherwise the main thread
+  // would simply shut down. So we call something that never returns unless there is an
+  // error. One good choice is simply to wait for all threads to return.
+  RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking));
+  return Status::OK();
+}
+
+Status CacheServer::GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q) {
+  RETURN_UNEXPECTED_IF_NULL(q);
+  CacheServer &cs = CacheServer::GetInstance();
+  CacheServerRequest *p;
+  RETURN_IF_NOT_OK(cs.free_list_->operator[](queue_id)->PopFront(&p));
+  *q = p;
+  return Status::OK();
+}
+
+Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
+  RETURN_UNEXPECTED_IF_NULL(p);
+  int32_t myQID = p->getQid();
+  // Free any memory from the protobufs
+  p->~CacheServerRequest();
+  // Re-initialize the memory
+  new (p) CacheServerRequest(myQID);
+  // Now we return it back to the free list.
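+  // (Destroying and then placement-new'ing the same slot releases whatever the old
+  // request's protobuf members held while reusing the pre-allocated storage, so no
+  // per-request trip to the allocator is needed.)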
+  CacheServer &cs = CacheServer::GetInstance();
+  RETURN_IF_NOT_OK(cs.free_list_->operator[](myQID)->Add(p));
+  return Status::OK();
+}
+
+Status CacheServer::DestroySession(CacheRequest *rq) {
+  CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id");
+  auto drop_session_id = rq->connection_info().session_id();
+  UniqueLock lck(&rwLock_);
+  // We can't just call DestroyCache() because we are already holding the lock; doing so
+  // would deadlock. So we do it manually. Note that erasing inside a range-based for
+  // loop would invalidate its iterator, so we iterate with explicit iterators instead.
+  for (auto it = all_caches_.begin(); it != all_caches_.end();) {
+    auto connection_id = it->first;
+    auto session_id = GetSessionID(connection_id);
+    if (session_id == drop_session_id) {
+      // std::map will invoke the destructor of CacheService. So we don't need to do anything here.
+      it = all_caches_.erase(it);
+      MS_LOG(INFO) << "Destroyed cache with connection id " << connection_id;
+    } else {
+      ++it;
+    }
+  }
+  return Status::OK();
+}
+
+session_id_type CacheServer::GenerateSessionID() const {
+  SharedLock lock(&rwLock_);
+  auto mt = GetRandomDevice();
+  std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max());
+  session_id_type session_id;
+  bool duplicate = false;
+  do {
+    session_id = distribution(mt);
+    auto it = history_sessions_.find(session_id);
+    duplicate = (it != history_sessions_.end());
+  } while (duplicate);
+  return session_id;
+}
+
+Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
+  auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10);
+  auto shared_pool = comm_layer_->GetSharedMemoryPool();
+  auto *base = shared_pool->SharedMemoryBaseAddr();
+  void *p = nullptr;
+  RETURN_IF_NOT_OK(shared_pool->Allocate(requestedSz, &p));
+  // We can't return the absolute address, which makes no sense to the client.
+  // Instead we return the difference.
+  auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
+  reply->set_result(std::to_string(difference));
+  return Status::OK();
+}
+
+Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
+  auto shared_pool = comm_layer_->GetSharedMemoryPool();
+  auto *base = shared_pool->SharedMemoryBaseAddr();
+  auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10);
+  auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
+  shared_pool->Deallocate(p);
+  return Status::OK();
+}
+
+Status CacheServer::RpcRequest(int32_t worker_id) {
+  TaskManager::FindMe()->Post();
+  RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id));
+  return Status::OK();
+}
+
+Status CacheServer::GlobalShutdown() {
+  // Let's shut down in proper order.
+  bool expected = false;
+  if (global_shutdown_.compare_exchange_strong(expected, true)) {
+    MS_LOG(WARNING) << "Shutting down server.";
+    // Shut down the grpc queue. No longer accept any newcomers.
+    // The threads we spawned to work on the grpc queue will exit by themselves once
+    // they notice the queue has been shut down.
+    comm_layer_->Shutdown();
+    // Now we interrupt any threads that are waiting on cache_q_
+    vg_.interrupt_all();
+    // The next thing to do is to drop all the caches.
+    UniqueLock lck(&rwLock_);
+    for (auto &it : all_caches_) {
+      auto id = it.first;
+      MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
+      // Wait for all outstanding work on this cache to be finished.
+      auto &cs = it.second;
+      UniqueLock cs_lock(&cs->rw_lock_);
+      // std::map will invoke the destructor of CacheService. So we don't need to do anything here.
+    }
+    // Erasing entries inside the range-based for loop above would invalidate its
+    // iterator, so we drop all the caches in one shot once their outstanding work
+    // has drained.
+    all_caches_.clear();
+  }
+  return Status::OK();
+}
+
+Status CacheServer::Builder::SanityCheck() {
+  if (shared_memory_sz_in_gb_ <= 0) {
+    RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive");
+  }
+  if (num_workers_ <= 0) {
+    RETURN_STATUS_UNEXPECTED("Number of parallel workers must be positive");
+  }
+  if (!top_.empty()) {
+    auto p = top_.data();
+    if (p[0] != '/') {
+      RETURN_STATUS_UNEXPECTED("Spilling directory must be an absolute path");
+    }
+    // Check if the spill directory is writable
+    Path spill(top_);
+    auto t = spill / Services::GetUniqueID();
+    Status rc = t.CreateDirectory();
+    if (rc.IsOk()) {
+      rc = t.Remove();
+    }
+    if (rc.IsError()) {
+      RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString());
+    }
+  }
+  return Status::OK();
+}
-CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers)
-    : top_(spill_path), num_workers_(num_workers) {}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
index c0dc8c467b..1ec5414d0c 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
@@ -24,8 +24,11 @@
 #include
 #include
 #include
+#include
 #include "minddata/dataset/engine/cache/cache_service.h"
+#include "minddata/dataset/engine/cache/cache_grpc_server.h"
 #include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/util/allocator.h"
 #include "minddata/dataset/util/arena.h"
 #include "minddata/dataset/util/cache_pool.h"
 #include "minddata/dataset/util/lock.h"
@@ -37,43 +40,131 @@
 namespace mindspore {
 namespace dataset {
-class BaseRequest;
 /// \brief A server which provides CacheService services.
 class CacheServer : public Service {
  public:
   friend class Services;
   using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;

+  class Builder {
+   public:
+    Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {}
+
+    /// \brief Getter functions
+    const std::string &getTop() const { return top_; }
+    int32_t getNumWorkers() const { return num_workers_; }
+    int32_t getPort() const { return port_; }
+    int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; }
+
+    Builder &SetRootDirectory(std::string root) {
+      top_ = std::move(root);
+      return *this;
+    }
+    Builder &SetNumWorkers(int32_t n) {
+      num_workers_ = n;
+      return *this;
+    }
+    Builder &SetPort(int32_t p) {
+      port_ = p;
+      return *this;
+    }
+    Builder &SetSharedMemorySizeInGB(int32_t sz) {
+      shared_memory_sz_in_gb_ = sz;
+      return *this;
+    }
+
+    Status SanityCheck();
+
+    void Print(std::ostream &out) const {
+      out << "Summary of the cache server configuration\n"
+          << "Spill directory: " << getTop() << "\n"
+          << "Number of parallel workers: " << getNumWorkers() << "\n"
+          << "TCP/IP port: " << getPort() << "\n"
+          << "Shared memory size (in GB): " << getSharedMemorySzInGb();
+    }
+
+    friend std::ostream &operator<<(std::ostream &out, const Builder &bld) {
+      bld.Print(out);
+      return out;
+    }
+
+    Status Build() {
+      RETURN_IF_NOT_OK(SanityCheck());
+      // We need to bring up the Task Manager by bringing up the Services singleton.
+      RETURN_IF_NOT_OK(Services::CreateInstance());
+      RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_));
+      return Status::OK();
+    }
+
+   private:
+    std::string top_;
+    int32_t num_workers_;
+    int32_t port_;
+    int32_t shared_memory_sz_in_gb_;
+  };

   CacheServer(const CacheServer &) = delete;
   CacheServer &operator=(const CacheServer &) = delete;
   CacheServer(CacheServer &&) = delete;
   CacheServer &operator=(CacheServer &&) = delete;
-  static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
   Status DoServiceStart() override;
   Status DoServiceStop() override;
   ~CacheServer() { (void)ServiceStop(); }

+  static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port,
+                               int32_t shared_memory_sz) {
+    std::call_once(init_instance_flag_, [&]() -> Status {
+      auto &svcManager = Services::GetInstance();
+      RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz));
+      return Status::OK();
+    });
+    return Status::OK();
+  }
+
+  static CacheServer &GetInstance() { return *instance_; }
+
   /// \brief For the current demonstration, a cache client contacts cache server using a Queue.
   /// \param rq
   /// \return Status object
-  Status PushRequest(BaseRequest *rq) {
+  Status PushRequest(int32_t queue_id, CacheServerRequest *rq) {
     RETURN_UNEXPECTED_IF_NULL(rq);
-    RETURN_IF_NOT_OK(cache_q_->Add(rq));
+    RETURN_IF_NOT_OK(cache_q_->operator[](queue_id)->Add(rq));
     return Status::OK();
   }

+  /// \brief Kick off the server threads. Never returns unless it errors out.
+  Status Run();
+
+  /// \brief Get a free tag
+  /// \param[in] q pointer to a pointer to a CacheServerRequest
+  /// \return Status object
+  static Status GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q);
+
+  /// \brief Return a tag to the free list
+  /// \param[in] p pointer to an already finished CacheServerRequest tag
+  /// \return Status object
+  static Status ReturnRequestTag(CacheServerRequest *p);
+
  private:
+  static std::once_flag init_instance_flag_;
+  static CacheServer *instance_;
   mutable RWLock rwLock_;
   std::string top_;
   cache_index all_caches_;
-  std::shared_ptr> cache_q_;
+  std::set<session_id_type> history_sessions_;
+  std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
+  std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
+  std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_;
+  std::shared_ptr<CacheServerGreeterImpl> comm_layer_;
+  std::shared_ptr<MemoryPool> mp_;
   TaskGroup vg_;
   int32_t num_workers_;
+  int32_t port_;
+  int32_t shared_memory_sz_in_gb_;
+  std::atomic<bool> global_shutdown_;

   /// \brief Constructor
   /// \param spill_path Top directory for spilling buffers to.
   /// \param num_workers Number of threads for handling requests.
-  explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3);
+  explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
+                       int32_t shared_memory_sz_in_gb);

   /// \brief Locate a cache service from connection id.
   /// \return Pointer to cache service. Null if not found
@@ -82,16 +173,65 @@ class CacheServer : public Service {
   /// \brief Create a cache service. We allow multiple clients to create the same cache service.
   ///        Subsequent duplicate requests are ignored. The first cache client to create the service will be given
   ///        a special unique cookie.
-  /// \param[in] connection_id This is from a Cache client.
- /// \param[in] cache_mem_sz - /// \param[in] flag - /// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator /// \return Status object - Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag, - std::string *out_cookie); + Status CreateService(CacheRequest *rq, CacheReply *reply); + + /// \brief Destroy a cache service + /// \param cs + /// \param rq + /// \return + Status DestroyCache(CacheService *cs, CacheRequest *rq); + Status PurgeCache(CacheService *cs); + + /// \brief Entry point for all internal server threads. + Status ServerRequest(int32_t worker_id); + + /// \brief Entry point for all grpc threads. + /// \return + Status RpcRequest(int32_t worker_id); + + Status DestroySession(CacheRequest *rq); + + /// \brief Create a connection id from a session id and a crc + /// \param session_id + /// \param crc + /// \return connection id + connection_id_type GetConnectionID(session_id_type session_id, uint32_t crc) const; + + /// \brief Extract the session id from a connection id + /// \param connection_id + /// \return session id + session_id_type GetSessionID(connection_id_type connection_id) const; + + /// \brief Generate a session ID for the client + /// \return Session ID + session_id_type GenerateSessionID() const; + + /// \brief Handle kAllocateSharedBlock request + /// \param rq CacheRequest + /// \param reply CacheReply + /// \return Status object + Status AllocateSharedMemory(CacheRequest *rq, CacheReply *reply); + + /// \brief Handle kFreeSharedBlock request + /// \param rq + /// \return Status object + Status FreeSharedMemory(CacheRequest *rq); - /// \brief Entry point for all server threads. - Status ServerRequest(); + /// \brief Handle kFastCacheRow request + /// \return Status object + Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply); + + /// \brief Internal function to do row batch fetch + /// \param cs CacheService + /// \param rq Request + /// \param reply Reply + /// \return + Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply); + + /// \brief A proper shutdown of the server + /// \return Status object + Status GlobalShutdown(); }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc index 4e1208d173..ee6e835dc6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc @@ -76,7 +76,7 @@ Status CacheService::CacheRow(const std::vector &buf, row_id_type *row_id_generated = GetNextRowId(); // Some debug information on how many rows we have generated so far. if ((*row_id_generated) % 1000 == 0) { - MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated; + MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1; } } else { if (msg->row_id() < 0) { @@ -114,6 +114,45 @@ Status CacheService::CacheRow(const std::vector &buf, row_id_type RETURN_STATUS_UNEXPECTED(e.what()); } } + +Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) { + SharedLock rw(&rw_lock_); + RETURN_UNEXPECTED_IF_NULL(row_id_generated); + if (st_ == State::kFetchPhase) { + // For this kind of cache service, once we are done with the build phase into fetch phase, we can't + // allow other to cache more rows. 
+    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
+  }
+  try {
+    // If we don't need to generate the id, we need to find it in the buffer.
+    if (generate_id_) {
+      *row_id_generated = GetNextRowId();
+      // Some debug information on how many rows we have generated so far.
+      if ((*row_id_generated) % 1000 == 0) {
+        MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1;
+      }
+    } else {
+      auto msg = GetTensorRowHeaderMsg(src.GetPointer());
+      if (msg->row_id() < 0) {
+        std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
+        RETURN_STATUS_UNEXPECTED(errMsg);
+      }
+      *row_id_generated = msg->row_id();
+    }
+    // Now we cache the flat buffer.
+    CachePool::key_type key;
+    RETURN_IF_NOT_OK(cp_->Insert({src}, &key));
+    Status rc = map_->DoInsert(*row_id_generated, key);
+    if (rc == Status(StatusCode::kDuplicateKey)) {
+      MS_LOG(DEBUG) << "Ignoring duplicate key.";
+    } else {
+      RETURN_IF_NOT_OK(rc);
+    }
+    return Status::OK();
+  } catch (const std::exception &e) {
+    RETURN_STATUS_UNEXPECTED(e.what());
+  }
+}
 std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
   // Then show any custom derived-internal stuff
   out << "\nCache memory size: " << cs.cache_mem_sz_;
@@ -155,20 +194,15 @@
   }
   return Status::OK();
 }
-Status CacheService::BatchFetch(const std::vector &v, MemGuard *out) const {
-  RETURN_UNEXPECTED_IF_NULL(out);
+
+Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out,
+                                   int64_t *mem_sz) {
   SharedLock rw(&rw_lock_);
-  if (st_ == State::kBuildPhase) {
-    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
-    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
-  }
+  RETURN_UNEXPECTED_IF_NULL(out);
+  RETURN_UNEXPECTED_IF_NULL(mem_sz);
   const auto num_elements = v.size();
-  int64_t mem_sz = (num_elements + 1) * sizeof(int64_t);
-  int64_t data_offset = mem_sz;
-  std::vector sz_v;
-  std::vector keys;
-  sz_v.reserve(num_elements);
-  keys.reserve(num_elements);
+  *mem_sz = (num_elements + 1) * sizeof(int64_t);
+  (*out).reserve(num_elements);
   for (auto row_id : v) {
     auto r = map_->Search(row_id);
     if (r.second) {
@@ -180,25 +214,33 @@
-  MemGuard mem;
-  RETURN_IF_NOT_OK(mem.allocate(mem_sz));
-  auto *offset_array = reinterpret_cast(mem.GetMutablePointer());
+  return Status::OK();
+}
+
+Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info,
+                                WritableSlice *out) const {
+  RETURN_UNEXPECTED_IF_NULL(out);
+  SharedLock rw(&rw_lock_);
+  if (st_ == State::kBuildPhase) {
+    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
+    RETURN_STATUS_UNEXPECTED("Can't accept a fetch request in the build phase");
+  }
+  const auto num_elements = v.size();
+  int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
+  auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
   offset_array[0] = data_offset;
-  WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes());
   for (auto i = 0; i < num_elements; ++i) {
-    auto sz = sz_v.at(i);
+    auto sz = info.at(i).second;
     offset_array[i + 1] = offset_array[i] + sz;
     if (sz > 0) {
-      WritableSlice row_data(all, offset_array[i], sz);
-      auto key = keys.at(i);
+      WritableSlice row_data(*out, offset_array[i], sz);
+      auto key = info.at(i).first;
       size_t bytesRead = 0;
       RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
       if (bytesRead != sz) {
@@ -208,7 +250,6 @@
-Status CacheService::FetchSchema(MemGuard *out) const {
+Status CacheService::FetchSchema(std::string *out) const {
   SharedLock rw(&rw_lock_);
   if (st_ == State::kBuildPhase) {
     // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
     RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
   }
   RETURN_UNEXPECTED_IF_NULL(out);
-  MemGuard mem;
+  // We are going to use std::string to allocate and hold the result, which will eventually
+  // be 'moved' to the protobuf message (which underneath is also a std::string) in order
+  // to minimize memory copies.
+  std::string mem;
   if (schema_key_ >= 0) {
     auto len = cp_->GetSize(schema_key_);
-    RETURN_IF_NOT_OK(mem.allocate(len));
-    auto slice = WritableSlice(mem.GetMutablePointer(), len);
+    try {
+      mem.resize(len);
+      CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error");
+    } catch (const std::bad_alloc &e) {
+      return Status(StatusCode::kOutOfMemory);
+    }
+    auto slice = WritableSlice(mem.data(), len);
     RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
     *out = std::move(mem);
   } else {
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h
index f4bd13e6ad..824d24975f 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h
@@ -28,7 +28,6 @@
 #include "minddata/dataset/core/global_context.h"
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/engine/cache/cache_request.h"
-#include "minddata/dataset/engine/cache/de_tensor_generated.h"
 #include "minddata/dataset/util/arena.h"
 #include "minddata/dataset/util/btree.h"
 #include "minddata/dataset/util/cache_pool.h"
@@ -38,7 +37,8 @@
 namespace mindspore {
 namespace dataset {
-struct CacheStat;
+/// Some typedefs used for BatchFetch
+using key_size_pair = std::pair<CachePool::key_type, size_t>;
 /// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is
 /// created to support spilling
 class CacheService : public Service {
@@ -69,12 +69,26 @@ class CacheService : public Service {
   /// \param[out] row_id_generated The row id assigned to this row if any
   /// \return Status object
   Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);
+
+  /// \brief A fast version of CacheRow where all the data is already in one contiguous piece.
+  /// \param src Slice of the data
+  /// \param row_id_generated
+  /// \return Status object
+  Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated);
+
+  /// \brief This function is used in preparation for batch fetching.
+  /// It calculates how much memory we should allocate and which row ids are present.
+  /// \param[in] v Vector of row ids to look up
+  /// \param[out] out Pointer to a vector of <key, size> pairs
+  /// \param[out] mem_sz How much memory is required to batch fetch
+  /// \return Status object
+  Status PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out, int64_t *mem_sz);
+
   /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
   /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
   /// \param[in] v A vector of row id.
   /// \param[out] out A contiguous memory buffer that holds the requested rows.
   /// \return Status object
-  Status BatchFetch(const std::vector &v, MemGuard *out) const;
+  Status BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info,
+                    WritableSlice *out) const;

   /// \brief Getter function
   /// \return Spilling path
@@ -102,7 +116,7 @@
   /// \brief Fetch schema
   /// \param out A contiguous memory that contains the serialized form of schema.
   /// \return Status object
-  Status FetchSchema(MemGuard *out) const;
+  Status FetchSchema(std::string *out) const;
   /// \brief Purge the content of a cache
   /// \return Status object
   Status Purge();
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs b/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs
index de26069f23..5d24995ed1 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs
@@ -60,10 +60,11 @@ table TensorRowIds {
 }

 /// Statistics returned from each cache service
-/// \note It must match CacheService::ServiceStat
+/// \note It must match CacheServiceStat
 table ServiceStatMsg {
   num_mem_cached:int64;
   num_disk_cached:int64;
+  avg_cache_sz:int64;
   min_row_id:int64;
   max_row_id:int64;
   state:int8;
@@ -79,3 +80,15 @@ table ColumnNameMsg {
 table SchemaMsg {
   column:[ColumnNameMsg];
 }
+
+/// Part of the CreateCacheRequest
+table CreateCacheRequestMsg {
+  cache_mem_sz:int64;
+  flag:uint32;
+}
+
+/// Return result of CreateCacheRequest
+table CreateCacheReplyMsg {
+  connection_id:int64;
+  cookie:string;
+}
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h
new file mode 100644
index 0000000000..2ad6cd045b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_
+
+#include
+#include
+#include "proto/cache_grpc.pb.h"
+#include "minddata/dataset/engine/cache/cache_common.h"
+#include "minddata/dataset/engine/cache/cache_request.h"
+#include "minddata/dataset/util/service.h"
+
+namespace mindspore {
+namespace dataset {
+class CacheClientGreeter : public Service {
+ public:
+  explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) {}
+  ~CacheClientGreeter() override {}
+  Status DoServiceStart() override { RETURN_STATUS_UNEXPECTED("Not supported"); }
+  Status DoServiceStop() override { RETURN_STATUS_UNEXPECTED("Not supported"); }
+
+  void *SharedMemoryBaseAddr() { return nullptr; }
+  Status HandleRequest(std::shared_ptr rq) { RETURN_STATUS_UNEXPECTED("Not supported"); }
+  Status AttachToSharedMemory(int32_t port, bool *local_bypass) { RETURN_STATUS_UNEXPECTED("Not supported"); }
+
+ protected:
+ private:
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
index 8e3a291d72..554eb2b19b 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
@@ -16,6 +16,7 @@
 #include "minddata/dataset/engine/datasetops/cache_base_op.h"
 #include
 #include
+#include
 #include "minddata/dataset/engine/execution_tree.h"

 namespace mindspore {
@@ -47,22 +48,39 @@ Status CacheBase::Reset() {
 }
 CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
                      std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, op_connector_size, sampler),
-      cache_client_(cache_client),
+    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
+      row_cnt_(0),
+      num_cache_miss_(0),
+      cache_client_(std::move(cache_client)),
       rows_per_buffer_(rows_per_buf),
       // We can cause deadlock if this internal Connector size is too small.
-      keys_miss_(num_workers_, 1, connector_capacity_) {
+      keys_miss_(num_workers_, 1, connector_capacity_),
+      prefetch_size_(cache_client_->getPrefetchSize()) {
   io_block_queues_.Init(num_workers, op_connector_size);
+  prefetch_queues_.Init(num_workers, op_connector_size);
+  sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size);
 }
+
 // Common function to fetch samples from the sampler and send them using the io_block_queues to
 // the parallel workers
 Status CacheBase::FetchSamplesToWorkers() {
   int64_t buf_cnt = 0;
   int64_t wait_cnt = 0;
+  // Kick off several threads which will prefetch prefetch_size_ rows in advance. The default
+  // rows_per_buffer_ is too small (1 by default) and won't help performance.
+  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this)));
+  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1)));
+  // Instead of sending sampler ids to the WorkerEntry, we send them to the Prefetcher, which will
+  // redirect them to the WorkerEntry.
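+  // Flow of sample ids, for illustration: sampler -> sampler_queue_ -> Dispatcher
+  // -> prefetch_queues_[i] -> Prefetcher -> prefetch_ pool, from which FetchFromCache()
+  // pops the rows by row id.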
   do {
     epoch_sync_.Clear();
+    if (AllowCacheMiss() && wait_cnt > 0) {
+      MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_
+                      << " Total number of rows : " << row_cnt_;
+    }
+    num_cache_miss_ = 0;
+    row_cnt_ = 0;
+    ++wait_cnt;
     std::vector<row_id_type> keys;
-    int64_t row_cnt = 0;
     keys.reserve(rows_per_buffer_);
     std::unique_ptr<DataBuffer> sampler_buffer;
     RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
@@ -70,10 +88,13 @@ Status CacheBase::FetchSamplesToWorkers() {
       TensorRow sample_row;
       RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
       std::shared_ptr<Tensor> sample_ids = sample_row[0];
+      // Send the sampler tensor to another thread for prefetching. We are using a shared pointer
+      // so it won't go out of scope until it is really no longer in use.
+      RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids));
       for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
         keys.push_back(*itr);
-        ++row_cnt;
-        if (row_cnt % rows_per_buffer_ == 0) {
+        ++row_cnt_;
+        if (row_cnt_ % rows_per_buffer_ == 0) {
           auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
           RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
           keys.clear();
@@ -90,7 +111,7 @@
       io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
     // If repeat but the not last repeat, wait for reset.
     if (!IsLastIteration()) {
-      MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
+      MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
       RETURN_IF_NOT_OK(epoch_sync_.Wait());
     } else {
       // We can break out from the loop.
@@ -101,13 +122,21 @@
   // Flow the eof before exit
   RETURN_IF_NOT_OK(
     io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
-  // Ask all the workers to quit.
+  // Shutdown threads
+  std::shared_ptr<Tensor> empty;
+  RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty)));
   for (int32_t i = 0; i < num_workers_; i++) {
    RETURN_IF_NOT_OK(
      io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
   }
+  // Dump the last epoch result (approximately) without waiting for the worker threads to come back.
+  if (AllowCacheMiss()) {
+    MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_
+                    << " Total number of rows : " << row_cnt_;
+  }
   return Status::OK();
 }
+
 Status CacheBase::FetchFromCache(int32_t worker_id) {
   int64_t buffer_id = worker_id;
   std::unique_ptr<IOBlock> blk;
@@ -133,23 +162,16 @@
     }
     std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
     std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
-    TensorTable ttbl;
-    RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl));
-    auto row_it = ttbl.begin();
     std::vector<row_id_type> cache_miss;
     cache_miss.reserve(keys.size());
     for (auto row_id : keys) {
-      auto &row = *row_it;
+      TensorRow row;
+      // Block until the row shows up in the pool.
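+      // (prefetch_ is a QueueMap keyed by row id: PopFront(row_id, &row) blocks until a
+      // Prefetcher thread Add()s a row under that id.)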
+      RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row));
       if (row.empty()) {
-        if (AllowCacheMiss()) {
-          cache_miss.push_back(row_id);
-        } else {
-          std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
-          RETURN_STATUS_UNEXPECTED(errMsg);
-        }
+        cache_miss.push_back(row_id);
       }
       que->push_back(std::move(row));
-      ++row_it;
     }
     db->set_tensor_table(std::move(que));
     if (AllowCacheMiss()) {
@@ -162,12 +184,17 @@
   } while (true);
   return Status::OK();
 }
+
 Status CacheBase::RegisterResources() {
   RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks()));
   return Status::OK();
 }
-CacheBase::~CacheBase() {}
+
+CacheBase::~CacheBase() = default;
+
 Status CacheBase::UpdateColumnMapFromCache() {
   Status rc;
   // Get the schema from the server. It may not be there yet. So tolerate the error.
@@ -180,5 +207,77 @@
   }
   return rc;
 }
+
+Status CacheBase::Dispatcher() {
+  TaskManager::FindMe()->Post();
+  int64_t buf_cnt = 0;
+  int64_t num_row = 0;
+  std::vector<row_id_type> keys;
+  keys.reserve(prefetch_size_);
+  do {
+    keys.clear();
+    std::shared_ptr<Tensor> sample_ids;
+    RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids));
+    if (sample_ids == nullptr) {
+      // A null shared pointer signals that it is time to quit.
+      // Also signal all prefetchers to quit.
+      for (int32_t i = 0; i < num_workers_; i++) {
+        RETURN_IF_NOT_OK(
+          prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
+      }
+      break;
+    }
+    // Now we distribute the sampler ids to each prefetcher according to the prefetch size.
+    for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
+      keys.push_back(*itr);
+      ++num_row;
+      if (num_row % prefetch_size_ == 0) {
+        auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
+        RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
+        keys.clear();
+      }
+    }
+    // Send the remaining sample ids
+    if (!keys.empty()) {
+      auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
+      RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
+    }
+  } while (true);
+  return Status::OK();
+}
+
+Status CacheBase::Prefetcher(int32_t worker_id) {
+  TaskManager::FindMe()->Post();
+  std::vector<row_id_type> prefetch_keys;
+  prefetch_keys.reserve(prefetch_size_);
+  do {
+    prefetch_keys.clear();
+    std::unique_ptr<IOBlock> blk;
+    RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk));
+    RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
+    if (prefetch_keys.empty()) {
+      // Empty keys mean it is time to quit.
+ break; + } + TensorTable ttbl; + RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl)); + auto row_it = ttbl.begin(); + for (auto row_id : prefetch_keys) { + auto &row = *row_it; + if (row.empty()) { + if (AllowCacheMiss()) { + ++num_cache_miss_; + } else { + std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; + RETURN_STATUS_UNEXPECTED(errMsg); + } + } + // Put the prefetched row into the pool and wake up any WorkerEntry waiting for the row + RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); + ++row_it; + } + } while (true); + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h index 40f3426394..2225d4f335 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h @@ -16,6 +16,8 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ +#include +#include #include #include #include @@ -28,8 +30,9 @@ #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/queue_map.h" +#include "minddata/dataset/util/semaphore.h" #include "minddata/dataset/util/wait_post.h" -#include "minddata/dataset/engine/datasetops/cache_base_op.h" namespace mindspore { namespace dataset { /// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. @@ -82,10 +85,13 @@ class CacheBase : public ParallelOp { protected: constexpr static int32_t eoe_row_id = -1; + int64_t row_cnt_; + std::atomic num_cache_miss_; std::shared_ptr cache_client_; WaitPost epoch_sync_; int32_t rows_per_buffer_; Connector> keys_miss_; + QueueMap prefetch_; /// \brief Common function to register resources for interrupt /// \note Derived should override this function for extra resources to be registered @@ -103,7 +109,15 @@ class CacheBase : public ParallelOp { private: constexpr static int32_t connector_capacity_ = 1024; + int32_t prefetch_size_; QueueList> io_block_queues_; + QueueList> prefetch_queues_; + std::unique_ptr>> sampler_queue_; + + Status Dispatcher(); + /// \brief Prefetcher. It prefetches rows from the cache server + /// \return Status object.
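As declared above, sampler ids flow through one sampler_queue_ into num_workers_ prefetch_queues_; the Dispatcher chunks them into blocks of prefetch_size_ and deals the blocks out round-robin (buf_cnt++ % num_workers_). A toy version of that distribution, with plain vectors standing in for the IOBlock queues:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const int32_t num_workers = 3;
      const int64_t prefetch_size = 4;
      std::vector<std::vector<std::vector<int64_t>>> queues(num_workers);
      std::vector<int64_t> keys;
      int64_t buf_cnt = 0;
      for (int64_t id = 0; id < 10; ++id) {
        keys.push_back(id);
        if (static_cast<int64_t>(keys.size()) == prefetch_size) {
          queues[buf_cnt++ % num_workers].push_back(keys);  // full block goes to the next queue
          keys.clear();
        }
      }
      if (!keys.empty()) queues[buf_cnt++ % num_workers].push_back(keys);  // remainder block
      for (int32_t i = 0; i < num_workers; ++i) {
        std::cout << "prefetch queue " << i << " received " << queues[i].size() << " block(s)\n";
      }
      return 0;
    }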
+ Status Prefetcher(int32_t worker_id); }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index b9be973d9c..aa6c93ba6e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -16,8 +16,10 @@ #include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include +#include #include #include +#include #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/global_context.h" @@ -41,9 +43,13 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const { out << "\n\n"; } } + CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, std::shared_ptr cache_client, const std::shared_ptr &sampler) - : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {} + : ParallelOp(numWorkers, opConnectorSize, sampler), + num_cleaners_(numCleaners), + cache_client_(std::move(cache_client)) {} + Status CacheMergeOp::operator()() { // A queue of row id to let cleaner send cache miss rows to the cache server // We don't want a small queue as this will block the parallel op workers. @@ -62,6 +68,7 @@ Status CacheMergeOp::operator()() { TaskManager::FindMe()->Post(); return Status::OK(); } + // Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait // until it shows up in the pool. Status CacheMergeOp::WorkerEntry(int32_t worker_id) { @@ -82,10 +89,8 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) { RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); if (row.empty()) { auto row_id = row.getId(); - TensorRowRequest *rq = nullptr; - RETURN_IF_NOT_OK(GetRq(row_id, &rq)); // Block until the row shows up in the pool. - RETURN_IF_NOT_OK(rq->Wait(&row)); + RETURN_IF_NOT_OK(cache_miss_.PopFront(row_id, &row)); } tbl->push_back(std::move(row)); } @@ -97,6 +102,7 @@ RETURN_IF_NOT_OK(EofReceived(worker_id)); return Status::OK(); } + Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { TaskManager::FindMe()->Post(); // We will simply pop TensorRow from the stream and insert them into the pool and @@ -123,17 +129,27 @@ std::string errMsg = "Expect positive row id: " + std::to_string(row_id); RETURN_STATUS_UNEXPECTED(errMsg); } - TensorRowRequest *rq = nullptr; + // Technically, the number of times this row shows up in the cache miss stream is equal to the number + // of P() calls. However, the cleaner wants it too, so we need an extra copy. + TensorRowCacheRequest *rq; RETURN_IF_NOT_OK(GetRq(row_id, &rq)); - rq->WakeUpAny(std::move(row)); - // Let the cleaner to flush out this row (async) to the cache server. - RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); + if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { + // We will send the request asynchronously. Any error we will most + // likely ignore and continue.
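Only the first thread to see a given missing row should issue the cache-write request; the kEmpty to kDirty compare-and-swap above guarantees that. The guard in isolation:

    #include <atomic>
    #include <cstdint>
    #include <iostream>

    enum class State : uint8_t { kEmpty = 0, kDirty = 1, kClean = 2 };

    // Returns true only for the caller that wins the kEmpty -> kDirty race;
    // that caller is the one that performs the (simulated) async send.
    bool TrySend(std::atomic<State> *st) {
      State expected = State::kEmpty;
      return st->compare_exchange_strong(expected, State::kDirty);
    }

    int main() {
      std::atomic<State> st{State::kEmpty};
      std::cout << TrySend(&st) << TrySend(&st) << "\n";  // prints 10: the second call loses
      return 0;
    }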
+ Status rc; + rc = rq->AsyncSendCacheRequest(cache_client_, row); + if (rc.IsOk()) { + RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); + } + } + RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row))); } } RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); } return Status::OK(); } + Status CacheMergeOp::Cleaner() { TaskManager::FindMe()->Post(); while (true) { @@ -142,45 +158,28 @@ Status CacheMergeOp::Cleaner() { if (row_id < 0) { break; } - TensorRowRequest *rq = nullptr; + // Locate the cache request + TensorRowCacheRequest *rq; RETURN_IF_NOT_OK(GetRq(row_id, &rq)); - if (rq->GetState() == TensorRowRequest::State::kClean) { - // If already flushed, move on to the next one. + // If already flushed, move on to the next one. + if (rq->GetState() == TensorRowCacheRequest::State::kClean) { continue; } - TensorRow row; - RETURN_IF_NOT_OK(rq->Release(&row)); - CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error."); - Status rc = cache_client_->WriteRow(row); - // Bad rc should not bring down the pipeline + Status rc = rq->CheckCacheResult(); if (rc.IsError()) { - MS_LOG(WARNING) << "Cache not successful." << rc.ToString(); + // If interrupt, time to quit. + if (rc.get_code() == StatusCode::kInterrupted) { + return Status::OK(); + } + MS_LOG(INFO) << "Cache row not successful: " << rc.ToString(); + // Bad rc should not bring down the pipeline. We will simply continue and + // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY. + rq->SetState(TensorRowCacheRequest::State::kEmpty); } - rq->SetState(TensorRowRequest::State::kClean); } return Status::OK(); } -Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) { - RETURN_UNEXPECTED_IF_NULL(out); - std::unique_lock lck(mux_); - auto it = cache_miss_map_.find(row_id); - if (it != cache_miss_map_.end()) { - *out = it->second.GetMutablePointer(); - } else { - // We will create a new one. - auto alloc = Services::GetAllocator(); - auto r = cache_miss_map_.emplace(row_id, MemGuard>(alloc)); - if (r.second) { - auto &mem = r.first->second; - RETURN_IF_NOT_OK(mem.allocate(1, row_id)); - *out = mem.GetMutablePointer(); - } else { - RETURN_STATUS_UNEXPECTED("Map insert fail."); - } - } - return Status::OK(); -} Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own // specific logic CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children"); @@ -199,6 +198,7 @@ Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from supe RETURN_IF_NOT_OK(rc); return Status::OK(); } + Status CacheMergeOp::ComputeColMap() { CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty"); if (column_name_id_map().empty()) { @@ -207,53 +207,13 @@ Status CacheMergeOp::ComputeColMap() { CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected"); return Status::OK(); } -Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) { - RETURN_UNEXPECTED_IF_NULL(out); - // Block until the missing row is in the pool. - RETURN_IF_NOT_OK(use_count_.P()); - std::unique_lock lck(dq_mux_); - CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); - *out = std::move(row_.front()); - row_.pop_front(); - return Status::OK(); -} -void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) { - std::unique_lock lck(dq_mux_); - // Technically number of this row shows up in the cache miss stream is equal to the number - // of P() call. 
However the cleaner wants it too. So we need an extra copy. - if (GetState() == State::kEmpty) { - // We will do a deep copy - for (auto &ts : row) { - std::shared_ptr out_ts; - Tensor::CreateFromTensor(ts, &out_ts); - cleaner_copy_.push_back(out_ts); - } - cleaner_copy_.setId(row.getId()); - // Change the state to dirty - SetState(State::kDirty); - } - row_.push_back(std::move(row)); - // Bump up the use count by 1. This wake up any parallel worker which is waiting - // for this row. - use_count_.V(); -} -Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) { - RETURN_UNEXPECTED_IF_NULL(out); - // We are not holding any mutex here because the cleaner isn't really touching the deque row_. - // In case we have multiple cleaners and they all see the copy, only one of them will - // get it. - auto expected = State::kDirty; - if (st_.compare_exchange_strong(expected, State::kClean)) { - *out = std::move(cleaner_copy_); - } - return Status::OK(); -} + // Builder constructor. Creates the builder object. CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { std::shared_ptr cfg = GlobalContext::config_manager(); build_num_workers_ = cfg->num_parallel_workers(); build_op_connector_size_ = cfg->op_connector_size(); - build_num_cleaners_ = 1; + build_num_cleaners_ = cfg->num_parallel_workers(); } // Check if the required parameters are set by the builder. @@ -311,5 +271,60 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) { MS_LOG(DEBUG) << "Cache merge sending eof"; return DatasetOp::EofReceived(worker_id); } + +Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheRequest **out) { + RETURN_UNEXPECTED_IF_NULL(out); + std::unique_lock lock(mux_); + auto it = io_request_.find(row_id); + if (it != io_request_.end()) { + *out = it->second.GetMutablePointer(); + } else { + // We will create a new one. + auto alloc = Services::GetAllocator(); + auto r = io_request_.emplace(row_id, MemGuard>(alloc)); + if (r.second) { + auto &mem = r.first->second; + RETURN_IF_NOT_OK(mem.allocate(1)); + *out = mem.GetMutablePointer(); + } else { + RETURN_STATUS_UNEXPECTED("Map insert fail."); + } + } + return Status::OK(); +} + +Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::shared_ptr &cc, + const TensorRow &row) { + auto expected = State::kEmpty; + if (st_.compare_exchange_strong(expected, State::kDirty)) { + // We will do a deep copy but write directly into CacheRequest protobuf or shared memory + Status rc; + cleaner_copy_ = + std::make_shared(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient()); + rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); + if (rc.IsOk()) { + // Send the request async. The cleaner will check the return code. + rc = cc->PushRequest(cleaner_copy_); + } + if (rc.IsError()) { + // Clean up the shared pointer and reset the state back to empty + cleaner_copy_.reset(); + st_ = State::kEmpty; + } + } + return Status::OK(); +} + +Status CacheMergeOp::TensorRowCacheRequest::CheckCacheResult() { + auto expected = State::kDirty; + if (st_.compare_exchange_strong(expected, State::kClean)) { + // Success or not, we will release the memory. + // We simply move it out of the structure and let it go out of scope. 
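GetRq() above lazily creates one request slot per row id under a single mutex. The same find-or-create shape, with std::unique_ptr standing in for MemGuard and its arena allocator:

    #include <cstdint>
    #include <map>
    #include <memory>
    #include <mutex>

    struct Request { /* state flag, cleaner copy, ... */ };

    class RequestRegistry {
     public:
      // Find the slot for row_id, creating it on first use.
      Request *GetRq(int64_t row_id) {
        std::lock_guard<std::mutex> lck(mux_);
        auto it = all_.find(row_id);
        if (it == all_.end()) {
          it = all_.emplace(row_id, std::make_unique<Request>()).first;
        }
        return it->second.get();
      }

     private:
      std::mutex mux_;
      std::map<int64_t, std::unique_ptr<Request>> all_;
    };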
+ auto cache_request = std::move(cleaner_copy_); + RETURN_IF_NOT_OK(cache_request->Wait()); + return Status::OK(); + } + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h index a4d92d1221..4c62af1d5c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ +#include #include #include #include @@ -28,6 +29,7 @@ #include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/dataset_iterator.h" #include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/queue_map.h" #include "minddata/dataset/util/semaphore.h" namespace mindspore { @@ -36,28 +38,34 @@ namespace dataset { /// stream class CacheMergeOp : public ParallelOp { public: - // Some handshake structures among the main thread, cleaner threads and parallel op threads. - class TensorRowRequest { + // Some handshake structures between CacheMissWorkerEntry and Cleaner + class TensorRowCacheRequest { public: enum class State : uint8_t { - kEmpty = 0, // No row in the deque + kEmpty = 0, // Initial state. Row hasn't arrived from cache miss stream yet. kDirty = 1, // Cleaner hasn't flushed it to the cache server yet. kClean = 2 // The row has been flushed already. }; - explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {} - ~TensorRowRequest() = default; + TensorRowCacheRequest() : st_(State::kEmpty) {} + ~TensorRowCacheRequest() = default; + /// Getter and Setter of the state State GetState() const { return st_; } void SetState(State newState) { st_ = newState; } - Status Wait(TensorRow *out); - void WakeUpAny(TensorRow &&row); - Status Release(TensorRow *out); + /// Take a tensor row and send rpc call to the server async + /// \param cc Cache client of the CacheMergeOp + /// \param row TensorRow to be sent to the server + /// \return Status object + /// \note Thread safe + Status AsyncSendCacheRequest(const std::shared_ptr &cc, const TensorRow &row); + + /// \brief We send the row to the server async so the CacheMissWorkerEntry can continue. + /// It is the cleaner that will check the result. + /// \return Status object + Status CheckCacheResult(); private: - std::mutex dq_mux_; std::atomic st_; - Semaphore use_count_; - std::deque row_; - TensorRow cleaner_copy_; + std::shared_ptr cleaner_copy_; }; constexpr static int kCacheHitChildIdx = 0; // Cache hit stream @@ -80,6 +88,8 @@ class CacheMergeOp : public ParallelOp { /// \return Builder setter method returns reference to the builder. 
Builder &SetNumWorkers(int32_t num_workers) { build_num_workers_ = num_workers; + // Adjust the number of cleaners to match the number of workers + build_num_cleaners_ = std::max(build_num_cleaners_, build_num_workers_); return *this; } @@ -159,7 +169,6 @@ class CacheMergeOp : public ParallelOp { /// \param workerId /// \return Status object Status CacheMissWorkerEntry(int32_t workerId); - Status GetRq(row_id_type row_id, TensorRowRequest **); /// \brief Base-class override for NodePass pre-visit acceptor /// \param[in] p The node to visit @@ -188,11 +197,18 @@ class CacheMergeOp : public ParallelOp { private: std::mutex mux_; - std::map>> cache_miss_map_; + QueueMap cache_miss_; + std::map>> io_request_; std::unique_ptr> io_que_; std::shared_ptr cache_client_; int32_t num_cleaners_; + /// \brief Locate the cache request from the io_request_ map + /// \param row_id + /// \param out pointer to the cache request + /// \return Status object + Status GetRq(row_id_type row_id, TensorRowCacheRequest **out); + /// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for /// moving cache miss TensorRow into the CacheServer. /// \return Status object diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc index c742d82522..8971841a23 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -142,7 +142,7 @@ Status CacheOp::WaitForCachingAllRows() { } // Get statistics from the server, and if we are not the one to create the cache, // wait until the state changed from build phase to fetch base. - CacheClient::ServiceStat stat{}; + CacheServiceStat stat{}; bool BuildPhaseDone = true; do { RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); @@ -157,6 +157,7 @@ Status CacheOp::WaitForCachingAllRows() { MS_LOG(INFO) << "Number of rows cached: " << num_rows_; MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached; MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached; + MS_LOG(INFO) << "Average cache size : " << stat.avg_cache_sz; // Now all rows are cached and we have done a sync point check up. Next phase is // is pick up fetch input from sampler and pass up to the caller. 
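The fetch-phase handshake above boils down to polling GetStat() until the server reports that the build phase is done, then reading the row counts out of the stat structure. A condensed sketch of that consumer side; it compiles only against the patched minddata tree, and the header path is an assumption:

    #include <iostream>
    #include <memory>
    #include "minddata/dataset/engine/cache/cache_client.h"  // assumed header location

    using mindspore::dataset::CacheClient;
    using mindspore::dataset::CacheServiceStat;
    using mindspore::dataset::Status;

    Status LogStat(const std::shared_ptr<CacheClient> &cc) {
      CacheServiceStat stat{};
      RETURN_IF_NOT_OK(cc->GetStat(&stat));  // one round trip to the cache server
      std::cout << stat.num_mem_cached << " rows in memory, " << stat.num_disk_cached
                << " spilled, average row size " << stat.avg_cache_sz << "\n";
      return Status::OK();
    }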
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index f866885f7f..881f2aff30 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -392,6 +392,13 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) { ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), ""); ss_str = std::regex_replace(ss_str, std::regex("\\[workers.*\\]"), ""); + // Filter out tcp/ip information + ss_str = std::regex_replace(ss_str, std::regex("Hostname.*\n"), ""); + ss_str = std::regex_replace(ss_str, std::regex("Port.*\n"), ""); + ss_str = std::regex_replace(ss_str, std::regex("Number of rpc workers.*\n"), ""); + ss_str = std::regex_replace(ss_str, std::regex("Prefetch size.*\n"), ""); + ss_str = std::regex_replace(ss_str, std::regex("Local client support.*\n"), ""); + // Filter out Number of rows when generating the check sum ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), ""); diff --git a/mindspore/ccsrc/minddata/dataset/include/status.h b/mindspore/ccsrc/minddata/dataset/include/status.h index b919b4dc4e..bb6a787914 100644 --- a/mindspore/ccsrc/minddata/dataset/include/status.h +++ b/mindspore/ccsrc/minddata/dataset/include/status.h @@ -73,6 +73,7 @@ enum class StatusCode : char { kProfilingError = 10, kBoundingBoxOutOfBounds = 11, kBoundingBoxInvalidShape = 12, + kSyntaxError = 13, // Make this error code the last one. Add new error code above it. kUnexpectedError = 127 }; diff --git a/mindspore/ccsrc/minddata/dataset/util/allocator.h b/mindspore/ccsrc/minddata/dataset/util/allocator.h index 8c64c2940e..82bb157052 100644 --- a/mindspore/ccsrc/minddata/dataset/util/allocator.h +++ b/mindspore/ccsrc/minddata/dataset/util/allocator.h @@ -168,9 +168,9 @@ class MemGuard { size_t GetSizeInBytes() const { return n_ * sizeof(T); } private: + size_t n_; allocator alloc_; std::unique_ptr ptr_; - size_t n_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/arena.h b/mindspore/ccsrc/minddata/dataset/util/arena.h index 2525bde142..c9c83d027b 100644 --- a/mindspore/ccsrc/minddata/dataset/util/arena.h +++ b/mindspore/ccsrc/minddata/dataset/util/arena.h @@ -27,20 +27,20 @@ #define ARENA_WALL_OVERHEAD_SZ 32 namespace mindspore { namespace dataset { -// This is a memory arena based on a treap data structure. -// The constructor of the Arena takes the size of the initial memory size (in MB). -// Internally we divide the memory into multiple blocks. Each block is 64 bytes. -// The treap contains all the free blocks with the relative memory address as key -// and the size of the block as priority. -// -// Initially the treap has only one root which is the whole memory piece. -// -// For memory suballocation, we pop the root node of the treap which contains the largest free block. -// We allocate what we need and return the rest back to the treap. We search for the first fit instead -// of the best fit so to give us a constant time in memory allocation. -// -// When a block of memory is freed. It is joined with the blocks before and after (if they are available) to -// form a bigger block. +/// This is a memory arena based on a treap data structure. +/// The constructor of the Arena takes the size of the initial memory size (in MB). +/// Internally we divide the memory into multiple blocks. 
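The new regex filters above make the op's CRC independent of where the cache server happens to live, so the same tree hashes identically across hosts and ports. A quick, self-contained demonstration of the stripping:

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
      std::string ss_str = "Op: CacheOp\nHostname: 127.0.0.1\nPort: 50052\nPrefetch size: 20\nNumber of rows: 12\n";
      // Same pattern as GenerateCRC: drop every line that names tcp/ip details.
      ss_str = std::regex_replace(ss_str, std::regex("Hostname.*\n"), "");
      ss_str = std::regex_replace(ss_str, std::regex("Port.*\n"), "");
      ss_str = std::regex_replace(ss_str, std::regex("Prefetch size.*\n"), "");
      ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), "");
      std::cout << ss_str;  // only "Op: CacheOp" remains
      return 0;
    }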
Each block is 64 bytes. +/// The treap contains all the free blocks with the relative memory address as key +/// and the size of the block as priority. +/// +/// Initially the treap has only one root which is the whole memory piece. +/// +/// For memory suballocation, we pop the root node of the treap which contains the largest free block. +/// We allocate what we need and return the rest back to the treap. We search for the first fit instead +/// of the best fit so as to give us constant time in memory allocation. +/// +/// When a block of memory is freed, it is joined with the blocks before and after (if they are available) to +/// form a bigger block. class Arena : public MemoryPool { public: Arena(const Arena &) = delete; @@ -78,7 +78,7 @@ class Arena : public MemoryPool { static Status CreateArena(std::shared_ptr *p_ba, size_t val_in_MB = 4096); - private: + protected: std::mutex mux_; Treap tr_; void *ptr_; diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc index 37c6107fb0..18016d6cea 100644 --- a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc @@ -140,13 +140,22 @@ Path CachePool::GetSpillPath() const { } CachePool::CacheStat CachePool::GetStat() const { CacheStat cs{0}; + int64_t total_sz = 0; for (auto &it : *tree_) { + total_sz += it.sz; if (it.ptr != nullptr) { ++cs.num_mem_cached; } else { ++cs.num_disk_cached; } } + if (total_sz > 0) { + // Integer arithmetic; no need to cast to float or double. + cs.average_cache_sz = total_sz / (cs.num_disk_cached + cs.num_mem_cached); + if (cs.average_cache_sz == 0) { + cs.average_cache_sz = 1; + } + } return cs; } Status CachePool::Spill(CachePool::DataLocator *dl) { diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h index 9bed5a2ef3..3989941a33 100644 --- a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h @@ -82,6 +82,7 @@ class CachePool : public Service { struct CacheStat { int64_t num_mem_cached; int64_t num_disk_cached; + int64_t average_cache_sz; }; /// \brief Constructor diff --git a/mindspore/ccsrc/minddata/dataset/util/queue_map.h b/mindspore/ccsrc/minddata/dataset/util/queue_map.h new file mode 100644 index 0000000000..3951ec14ce --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/queue_map.h @@ -0,0 +1,127 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/semaphore.h" +#include "minddata/dataset/util/services.h" +namespace mindspore { +namespace dataset { +template +/// \brief QueueMap is like a Queue, but instead it maintains a map of deques.
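The average row size above deliberately stays in integer arithmetic and is clamped up to 1, so a non-empty cache never reports an average of zero. The same computation in isolation:

    #include <cstdint>
    #include <iostream>

    int64_t AverageCacheSz(int64_t total_sz, int64_t num_rows) {
      if (num_rows <= 0 || total_sz <= 0) return 0;
      int64_t avg = total_sz / num_rows;  // truncating integer division
      return avg == 0 ? 1 : avg;          // floor at 1 for tiny rows
    }

    int main() {
      std::cout << AverageCacheSz(10, 3) << " " << AverageCacheSz(2, 5) << "\n";  // prints 3 1
      return 0;
    }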
+/// Consumer will block if the corresponding deque is empty. +/// Producer can add an element of type T with key of type K to the map and +/// wake up any waiting consumer. +/// \tparam K key type +/// \tparam T payload of the map +class QueueMap { + public: + using key_type = K; + using value_type = T; + + QueueMap() = default; + virtual ~QueueMap() = default; + + /// Add an element to the map and wake up any consumer that is waiting + /// \param key + /// \param payload + /// \return Status object + virtual Status Add(key_type key, T &&payload) { + RequestQueue *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(key, &rq)); + RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload))); + return Status::OK(); + } + + /// Pop the front of the deque with key. Block if the deque is empty. + virtual Status PopFront(key_type key, T *out) { + RequestQueue *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(key, &rq)); + RETURN_IF_NOT_OK(rq->Wait(out)); + return Status::OK(); + } + + protected: + /// This is a handshake structure between producer and consumer + class RequestQueue { + public: + RequestQueue() : use_count_(0) {} + ~RequestQueue() = default; + + Status Wait(T *out) { + RETURN_UNEXPECTED_IF_NULL(out); + // Block until the missing row is in the pool. + RETURN_IF_NOT_OK(use_count_.P()); + std::unique_lock lck(dq_mux_); + CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); + *out = std::move(row_.front()); + row_.pop_front(); + return Status::OK(); + } + + Status WakeUpAny(T &&row) { + std::unique_lock lck(dq_mux_); + row_.push_back(std::move(row)); + // Bump up the use count by 1. This wake up any parallel worker which is waiting + // for this row. + use_count_.V(); + return Status::OK(); + } + + private: + std::mutex dq_mux_; + Semaphore use_count_; + std::deque row_; + }; + + /// Create or locate an element with matching key + /// \param key + /// \param out + /// \return Status object + Status GetRq(key_type key, RequestQueue **out) { + RETURN_UNEXPECTED_IF_NULL(out); + std::unique_lock lck(mux_); + auto it = all_.find(key); + if (it != all_.end()) { + *out = it->second.GetMutablePointer(); + } else { + // We will create a new one. 
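Each RequestQueue above pairs one V() per deposited row with one P() per consumer, so a consumer may even arrive before its row does. The same discipline expressed with C++20's std::counting_semaphore (the in-tree Semaphore predates it):

    #include <iostream>
    #include <semaphore>  // C++20
    #include <thread>

    int main() {
      std::counting_semaphore<64> use_count(0);
      std::thread consumer([&] {
        use_count.acquire();  // P(): block until a row has been produced
        std::cout << "row consumed\n";
      });
      std::cout << "row produced\n";
      use_count.release();    // V(): wake one waiting consumer
      consumer.join();
      return 0;
    }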
+ auto alloc = Services::GetAllocator(); + auto r = all_.emplace(key, MemGuard>(alloc)); + if (r.second) { + auto &mem = r.first->second; + RETURN_IF_NOT_OK(mem.allocate(1)); + *out = mem.GetMutablePointer(); + } else { + RETURN_STATUS_UNEXPECTED("Map insert fail."); + } + } + return Status::OK(); + } + + private: + std::mutex mux_; + std::map>> all_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/services.cc b/mindspore/ccsrc/minddata/dataset/util/services.cc index 44eba24ca6..882dc2628f 100644 --- a/mindspore/ccsrc/minddata/dataset/util/services.cc +++ b/mindspore/ccsrc/minddata/dataset/util/services.cc @@ -22,7 +22,6 @@ #include #endif #include -#include "minddata/dataset/engine/cache/cache_server.h" #include "minddata/dataset/util/circular_pool.h" #include "minddata/dataset/util/random.h" #include "minddata/dataset/util/task_manager.h" @@ -59,35 +58,15 @@ std::string Services::GetUniqueID() { return std::string(buffer, UNIQUEID_LEN); } -TaskManager &Services::getTaskMgrInstance() { - Services &sm = GetInstance(); - return *(static_cast(sm.sa_[kSlotTaskMgr_])); -} - -CacheServer &Services::getCacheServer() { - Services &sm = GetInstance(); - return *(static_cast(sm.sa_[kSlotCacheMgr_])); -} - Status Services::CreateAllInstances() { - // In order, TaskMgr, BufferMgr - Status rc; - sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager(); - RETURN_IF_NOT_OK(rc); - rc = sa_[kSlotTaskMgr_]->ServiceStart(); - RETURN_IF_NOT_OK(rc); - // TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers -#if !defined(_WIN32) && !defined(_WIN64) - sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3); - RETURN_IF_NOT_OK(rc); - rc = sa_[kSlotCacheMgr_]->ServiceStart(); -#else - sa_[kSlotCacheMgr_] = nullptr; -#endif - return rc; + // First one is always the TaskManager + RETURN_IF_NOT_OK(TaskManager::CreateInstance()); + TaskManager &tm = TaskManager::GetInstance(); + RETURN_IF_NOT_OK(tm.ServiceStart()); + return Status::OK(); } -Services::Services() : pool_(nullptr), sa_{nullptr} { +Services::Services() : pool_(nullptr) { Status rc = CircularPool::CreateCircularPool(&pool_, -1, 16, true); // each arena 16M if (rc.IsError()) { std::terminate(); @@ -95,22 +74,11 @@ Services::Services() : pool_(nullptr), sa_{nullptr} { } Services::~Services() noexcept { - try { - // In reverse order - CacheServer *cs = static_cast(sa_[kSlotCacheMgr_]); - if (cs != nullptr) { - (void)cs->ServiceStop(); - cs->~CacheServer(); - pool_->Deallocate(cs); - } - TaskManager *tm = static_cast(sa_[kSlotTaskMgr_]); - if (tm != nullptr) { - (void)tm->ServiceStop(); - tm->~TaskManager(); - pool_->Deallocate(tm); - } - } catch (const std::exception &e) { - // Do nothing. + // Shutdown in reverse order. 
+ auto n = hook_.size(); + while (n > 0) { + hook_.pop_back(); + n = hook_.size(); } } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/util/services.h b/mindspore/ccsrc/minddata/dataset/util/services.h index 9d4dca9765..a077f670cb 100644 --- a/mindspore/ccsrc/minddata/dataset/util/services.h +++ b/mindspore/ccsrc/minddata/dataset/util/services.h @@ -16,9 +16,11 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ +#include #include #include #include +#include #include "minddata/dataset/util/memory_pool.h" #include "minddata/dataset/util/allocator.h" #include "minddata/dataset/util/service.h" @@ -27,7 +29,7 @@ namespace mindspore { namespace dataset { class TaskManager; -class CacheServer; + class Services { public: static Status CreateInstance() { @@ -59,10 +61,6 @@ class Services { ~Services() noexcept; - static TaskManager &getTaskMgrInstance(); - - static CacheServer &getCacheServer(); - std::shared_ptr GetServiceMemPool() { return pool_; } #if !defined(_WIN32) && !defined(_WIN64) @@ -80,19 +78,29 @@ class Services { return Allocator(Services::GetInstance().GetServiceMemPool()); } + /// \brief Add a new service to the start up list. + /// \tparam T Class that implements Service + /// \return Status object and where the service is located in the hook_ list + template + Status AddHook(T **out, Args &&... args) { + RETURN_UNEXPECTED_IF_NULL(out); + try { + (*out) = new T(std::forward(args)...); + std::unique_ptr svc(*out); + hook_.push_back(std::move(svc)); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory); + } + return Status::OK(); + } + private: static std::once_flag init_instance_flag_; static std::unique_ptr instance_; // A small pool used for small objects that last until the // Services Manager shuts down. Used by all sub-services. std::shared_ptr pool_; - // We use pointers here instead of unique_ptr because we - // want to have ultimate control on the order of - // construction and destruction. - static constexpr int kSlotTaskMgr_ = 0; - static constexpr int kSlotCacheMgr_ = 1; - static constexpr int kNumServices_ = 2; - Service *sa_[kNumServices_]; + std::vector> hook_; Services(); diff --git a/mindspore/ccsrc/minddata/dataset/util/slice.h b/mindspore/ccsrc/minddata/dataset/util/slice.h index 304a7e8698..058b822332 100644 --- a/mindspore/ccsrc/minddata/dataset/util/slice.h +++ b/mindspore/ccsrc/minddata/dataset/util/slice.h @@ -86,6 +86,7 @@ class ReadableSlice { class WritableSlice : public ReadableSlice { public: friend class StorageContainer; + friend class CacheService; /// \brief Default constructor WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} /// \brief This form of a constructor takes a pointer and its size. 
diff --git a/mindspore/ccsrc/minddata/dataset/util/status.cc b/mindspore/ccsrc/minddata/dataset/util/status.cc index 9d60bfe6a6..438686fa3f 100644 --- a/mindspore/ccsrc/minddata/dataset/util/status.cc +++ b/mindspore/ccsrc/minddata/dataset/util/status.cc @@ -48,6 +48,9 @@ std::string CodeAsString(const StatusCode c) { case StatusCode::kProfilingError: s = "Error encountered while profiling"; break; + case StatusCode::kSyntaxError: + s = "Syntax error"; + break; case StatusCode::kUnexpectedError: default: s = "Unexpected error"; diff --git a/mindspore/ccsrc/minddata/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/util/status.h index b5500c1013..1c822195c4 100644 --- a/mindspore/ccsrc/minddata/dataset/util/status.h +++ b/mindspore/ccsrc/minddata/dataset/util/status.h @@ -80,6 +80,7 @@ enum class StatusCode : char { kProfilingError = 10, kBoundingBoxOutOfBounds = 11, kBoundingBoxInvalidShape = 12, + kSyntaxError = 13, // Make this error code the last one. Add new error code above it. kUnexpectedError = 127 }; diff --git a/mindspore/ccsrc/minddata/dataset/util/task_manager.cc b/mindspore/ccsrc/minddata/dataset/util/task_manager.cc index e72fed5d07..38019155a2 100644 --- a/mindspore/ccsrc/minddata/dataset/util/task_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/util/task_manager.cc @@ -21,6 +21,8 @@ namespace mindspore { namespace dataset { +TaskManager *TaskManager::instance_ = nullptr; +std::once_flag TaskManager::init_instance_flag_; // This takes the same parameter as Task constructor. Status TaskManager::CreateAsyncTask(const std::string &my_name, const std::function &f, TaskGroup *vg, Task **task) { diff --git a/mindspore/ccsrc/minddata/dataset/util/task_manager.h b/mindspore/ccsrc/minddata/dataset/util/task_manager.h index 7b81bc8c71..2a3e3f07c8 100644 --- a/mindspore/ccsrc/minddata/dataset/util/task_manager.h +++ b/mindspore/ccsrc/minddata/dataset/util/task_manager.h @@ -54,7 +54,16 @@ class TaskManager : public Service { TaskManager &operator=(const TaskManager &) = delete; - static TaskManager &GetInstance() noexcept { return Services::getTaskMgrInstance(); } + static Status CreateInstance() { + std::call_once(init_instance_flag_, [&]() -> Status { + auto &svcManager = Services::GetInstance(); + RETURN_IF_NOT_OK(svcManager.AddHook(&instance_)); + return Status::OK(); + }); + return Status::OK(); + } + + static TaskManager &GetInstance() noexcept { return *instance_; } Status DoServiceStart() override; @@ -96,6 +105,8 @@ class TaskManager : public Service { Status WatchDog(); private: + static std::once_flag init_instance_flag_; + static TaskManager *instance_; RWLock lru_lock_; SpinLock free_lock_; SpinLock tg_lock_; diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py index d140a0cb55..32a9829349 100644 --- a/mindspore/dataset/engine/cache_client.py +++ b/mindspore/dataset/engine/cache_client.py @@ -25,15 +25,22 @@ class DatasetCache: A client to interface with tensor caching service """ - def __init__(self, session_id=None, size=0, spilling=False): + def __init__(self, session_id=None, size=0, spilling=False, port=50052, prefetch_size=20): check_uint32(session_id, "session_id") check_uint64(size, "size") type_check(spilling, (bool,), "spilling") + check_uint32(port, "port") + check_uint32(prefetch_size, "prefetch size") self.session_id = session_id self.size = size self.spilling = spilling - self.cache_client = CacheClient(session_id, size, spilling) + self.port = port + self.prefetch_size = prefetch_size + self.cache_client = 
CacheClient(session_id, size, spilling, port, prefetch_size) + + def GetStat(self): + return self.cache_client.GetStat() def __deepcopy__(self, memodict): if id(self) in memodict: @@ -44,5 +51,7 @@ class DatasetCache: new_cache.session_id = copy.deepcopy(self.session_id, memodict) new_cache.spilling = copy.deepcopy(self.spilling, memodict) new_cache.size = copy.deepcopy(self.size, memodict) + new_cache.port = copy.deepcopy(self.port, memodict) + new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) new_cache.cache_client = self.cache_client return new_cache diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index 26db41ef66..d408ae3e72 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -43,13 +43,18 @@ class MindDataTestCacheOp : public UT::DatasetOpTesting { } }; -TEST_F(MindDataTestCacheOp, TestCacheServer) { +TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { Status rc; - CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true + CacheClient::Builder builder; + // use arbitrary session of 1, size of 0, spilling is true + builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); + std::shared_ptr myClient; + rc = builder.Build(&myClient); + ASSERT_TRUE(rc.IsOk()); // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated. - rc = myClient.CreateCache(1, true); - EXPECT_TRUE(rc.IsOk()); - std::cout << myClient << std::endl; + rc = myClient->CreateCache(1, true); + ASSERT_TRUE(rc.IsOk()); + std::cout << *myClient << std::endl; // Create a schema using the C api's int32_t rank = 0; // not used @@ -68,11 +73,11 @@ std::unordered_map map; rc = testSchema->GetColumnNameMap(&map); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // Test the CacheSchema api - rc = myClient.CacheSchema(map); - EXPECT_TRUE(rc.IsOk()); + rc = myClient->CacheSchema(map); + ASSERT_TRUE(rc.IsOk()); // Create a tensor, take a snapshot and restore it back, and compare. std::shared_ptr t; @@ -88,48 +93,54 @@ TensorRow row; row.push_back(t); int64_t row_id; - rc = myClient.WriteRow(row, &row_id); - EXPECT_TRUE(rc.IsOk()); + rc = myClient->WriteRow(row, &row_id); + ASSERT_TRUE(rc.IsOk()); // Switch off build phase. - rc = myClient.BuildPhaseDone(); - EXPECT_TRUE(rc.IsOk()); + rc = myClient->BuildPhaseDone(); + ASSERT_TRUE(rc.IsOk()); // Now restore from cache.
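Every test below repeats the same client construction, which replaces the old direct CacheClient(session, size, spill) constructor. Condensed here for reference, assuming the patched minddata tree (the header path is an assumption):

    #include <memory>
    #include "minddata/dataset/engine/cache/cache_client.h"  // assumed header location

    using mindspore::dataset::CacheClient;
    using mindspore::dataset::Status;

    Status MakeClient(std::shared_ptr<CacheClient> *out) {
      CacheClient::Builder builder;
      // arbitrary session 1, memory size 0, spilling enabled, as in the tests
      builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
      return builder.Build(out);
    }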
row.clear(); - rc = myClient.GetRows({row_id}, &tbl); + rc = myClient->GetRows({row_id}, &tbl); row = tbl.front(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); auto r = row.front(); std::cout << *r << std::endl; // Compare bool cmp = (*t == *r); - EXPECT_TRUE(cmp); + ASSERT_TRUE(cmp); // Get back the schema and verify std::unordered_map map_out; - rc = myClient.FetchSchema(&map_out); - EXPECT_TRUE(rc.IsOk()); + rc = myClient->FetchSchema(&map_out); + ASSERT_TRUE(rc.IsOk()); cmp = (map_out == map); - EXPECT_TRUE(cmp); + ASSERT_TRUE(cmp); // Test Purge and Destroy - rc = myClient.PurgeCache(); - EXPECT_TRUE(rc.IsOk()); - rc = myClient.DestroyCache(); - EXPECT_TRUE(rc.IsOk()); + rc = myClient->PurgeCache(); + ASSERT_TRUE(rc.IsOk()); + rc = myClient->DestroyCache(); + ASSERT_TRUE(rc.IsOk()); } -TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { +TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { // Clear the rc of the master thread if any (void)TaskManager::GetMasterThreadRc(); TaskGroup vg; Status rc; - CacheClient myClient(1, 1, true); // use arbitrary session of 1, size 1, spilling is true + // use arbitrary session of 1, size 1, spilling is true + CacheClient::Builder builder; + builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true); + std::shared_ptr myClient; + rc = builder.Build(&myClient); + ASSERT_TRUE(rc.IsOk()); // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated. - rc = myClient.CreateCache(1, true); - EXPECT_TRUE(rc.IsOk()); - std::cout << myClient << std::endl; + rc = myClient->CreateCache(1, true); + ASSERT_TRUE(rc.IsOk()); + std::cout << *myClient << std::endl; std::shared_ptr t; Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); t->SetItemAt({0, 0}, 1); @@ -146,19 +157,19 @@ Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status { TaskManager::FindMe()->Post(); for (auto i = 0; i < 500; i++) { - RETURN_IF_NOT_OK(myClient.WriteRow(row)); + RETURN_IF_NOT_OK(myClient->WriteRow(row)); } return Status::OK(); }); - EXPECT_TRUE(vg_rc.IsOk()); + ASSERT_TRUE(vg_rc.IsOk()); } ASSERT_TRUE(vg.join_all().IsOk()); ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk()); - rc = myClient.BuildPhaseDone(); + rc = myClient->BuildPhaseDone(); ASSERT_TRUE(rc.IsOk()); // Get statistics from the server.
- CacheClient::ServiceStat stat{}; - rc = myClient.GetStat(&stat); + CacheServiceStat stat{}; + rc = myClient->GetStat(&stat); ASSERT_TRUE(rc.IsOk()); std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached << "\n"; @@ -168,15 +179,15 @@ for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) { tbl.clear(); row.clear(); - rc = myClient.GetRows({i}, &tbl); - EXPECT_TRUE(rc.IsOk()); + rc = myClient->GetRows({i}, &tbl); + ASSERT_TRUE(rc.IsOk()); row = tbl.front(); auto r = row.front(); bool cmp = (*t == *r); - EXPECT_TRUE(cmp); + ASSERT_TRUE(cmp); } - rc = myClient.DestroyCache(); - EXPECT_TRUE(rc.IsOk()); + rc = myClient->DestroyCache(); + ASSERT_TRUE(rc.IsOk()); } // Simple test with a repeated cache op over random data producer // @@ -187,7 +198,7 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) // | // RandomDataOp // -TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { +TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { Status rc; int32_t rank = 0; // not used MS_LOG(INFO) << "UT test TestRandomDataCache1"; @@ -218,13 +229,18 @@ .SetDataSchema(std::move(testSchema)) .SetTotalRows(50) // 50 samples for now .Build(&myRandomDataOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myRandomDataOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // CacheOp // size of 0, spilling is true - std::shared_ptr myClient = std::make_shared(1, 0, true); + CacheClient::Builder builder; + // use arbitrary session of 1, size of 0, spilling is true + builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); + std::shared_ptr myClient; + rc = builder.Build(&myClient); + ASSERT_TRUE(rc.IsOk()); std::shared_ptr myCacheOp; int64_t num_samples = 0; @@ -236,29 +252,29 @@ .SetRowsPerBuffer(4) .SetSampler(std::move(seq_sampler)) .Build(&myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // RepeatOp uint32_t numRepeats = 4; std::shared_ptr myRepeatOp; rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // Assign tree relations and root rc = myRepeatOp->AddChild(myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myCacheOp->AddChild(myRandomDataOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssignRoot(myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); MS_LOG(INFO) << "Launching tree and begin iteration"; rc = myTree->Prepare(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // quick check to see what tree looks like std::ostringstream ss; @@ -268,24 +284,24 @@ std::cout << *myClient << std::endl; rc = myTree->Launch(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // Start the loop of reading tensors from our pipeline DatasetIterator dI(myTree); TensorRow tensorList; rc = dI.FetchNextTensorRow(&tensorList); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); int rowCount = 0; while (!tensorList.empty()) { // Don't display these rows, just count them MS_LOG(INFO) << "Row fetched #: " << rowCount; rc = dI.FetchNextTensorRow(&tensorList); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk());
rowCount++; } ASSERT_EQ(rowCount, 200); rc = myClient->DestroyCache(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); } //// Simple test with a repeated cache op over random data producer. @@ -297,7 +313,7 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { //// | //// RandomDataOp //// -TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { +TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { Status rc; int32_t rank = 0; // not used MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; @@ -328,15 +344,20 @@ .SetDataSchema(std::move(testSchema)) .SetTotalRows(10) .Build(&myRandomDataOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myRandomDataOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // CacheOp int64_t num_samples = 0; int64_t start_index = 0; auto seq_sampler = std::make_shared(num_samples, start_index); - std::shared_ptr myClient = std::make_shared(1, 4, true); + CacheClient::Builder builder; + // use arbitrary session of 1, size of 4, spilling is true + builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); + std::shared_ptr myClient; + rc = builder.Build(&myClient); + ASSERT_TRUE(rc.IsOk()); std::shared_ptr myCacheOp; rc = CacheOp::Builder() .SetNumWorkers(4) .SetClient(myClient) .SetRowsPerBuffer(3) .SetSampler(std::move(seq_sampler)) .Build(&myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // RepeatOp uint32_t numRepeats = 4; std::shared_ptr myRepeatOp; rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // Assign tree relations and root rc = myRepeatOp->AddChild(myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myCacheOp->AddChild(myRandomDataOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssignRoot(myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); MS_LOG(INFO) << "Launching tree and begin iteration"; rc = myTree->Prepare(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); std::cout << *myClient << std::endl; rc = myTree->Launch(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // Start the loop of reading tensors from our pipeline DatasetIterator dI(myTree); TensorRow tensorList; rc = dI.FetchNextTensorRow(&tensorList); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); int rowCount = 0; while (!tensorList.empty()) { // Don't display these rows, just count them MS_LOG(INFO) << "Row fetched #: " << rowCount; rc = dI.FetchNextTensorRow(&tensorList); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rowCount++; } ASSERT_EQ(rowCount, 40); rc = myClient->DestroyCache(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); } -TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { +TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { Status rc; int64_t num_samples = 0; int64_t start_index = 0; auto seq_sampler = std::make_shared(num_samples, start_index); - std::shared_ptr myClient = std::make_shared(1, 0, true); + CacheClient::Builder ccbuilder; + // use arbitrary session of 1, size of 0, spilling is true + ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); + std::shared_ptr myClient; + rc = ccbuilder.Build(&myClient); + ASSERT_TRUE(rc.IsOk()); // In a mappable
dataset, it uses a complex interactions of cache lookup op and cache merge op. // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by @@ -417,44 +443,44 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { .SetRecursive(true) .SetImageFolderDir(datasets_root_path_ + "/testPK/data"); rc = builder.Build(&so); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // RepeatOp uint32_t numRepeats = 4; std::shared_ptr myRepeatOp; rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); auto myTree = std::make_shared(); rc = myTree->AssociateNode(so); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssignRoot(myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myRepeatOp->AddChild(myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myCacheOp->AddChild(so); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->Prepare(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->Launch(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // Start the loop of reading tensors from our pipeline DatasetIterator dI(myTree); TensorRow tensorList; rc = dI.FetchNextTensorRow(&tensorList); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); int rowCount = 0; while (!tensorList.empty()) { rc = dI.FetchNextTensorRow(&tensorList); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); if (rc.IsError()) { std::cout << rc << std::endl; break; @@ -464,7 +490,7 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { ASSERT_EQ(rowCount, 176); std::cout << "Row count : " << rowCount << std::endl; rc = myClient->DestroyCache(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); } //// Simple test with a repeated cache op over random data producer. 
@@ -480,7 +506,7 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { //// | //// RandomDataOp //// -TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) { +TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { Status rc; int32_t rank = 0; // not used MS_LOG(INFO) << "UT test TestCacheInheritSampler"; @@ -517,57 +543,62 @@ .SetTotalRows(10) .SetSampler(std::move(seq_sampler)) .Build(&myRandomDataOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myRandomDataOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // CacheOp - std::shared_ptr myClient = std::make_shared(1, 4, true); + CacheClient::Builder ccbuilder; + // use arbitrary session of 1, size of 4, spilling is true + ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); + std::shared_ptr myClient; + rc = ccbuilder.Build(&myClient); + ASSERT_TRUE(rc.IsOk()); std::shared_ptr myCacheOp; rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // RepeatOp uint32_t numRepeats = 4; std::shared_ptr myRepeatOp; rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // Assign tree relations and root rc = myRepeatOp->AddChild(myCacheOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myCacheOp->AddChild(myRandomDataOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rc = myTree->AssignRoot(myRepeatOp); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); MS_LOG(INFO) << "Launching tree and begin iteration"; rc = myTree->Prepare(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); std::cout << *myClient << std::endl; rc = myTree->Launch(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); // Start the loop of reading tensors from our pipeline DatasetIterator dI(myTree); TensorRow tensorList; rc = dI.FetchNextTensorRow(&tensorList); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); int rowCount = 0; while (!tensorList.empty()) { // Don't display these rows, just count them MS_LOG(INFO) << "Row fetched #: " << rowCount; rc = dI.FetchNextTensorRow(&tensorList); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); rowCount++; } ASSERT_EQ(rowCount, 40); rc = myClient->DestroyCache(); - EXPECT_TRUE(rc.IsOk()); + ASSERT_TRUE(rc.IsOk()); } diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 154a4208a0..5fcc5d0866 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -15,6 +15,8 @@ """ Testing cache operator with mappable datasets """ +import os +import pytest import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision from mindspore import log as logger @@ -25,6 +27,7 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" GENERATE_GOLDEN = False +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_basic1(): """ Test mappable leaf with cache op right over the leaf @@ -53,7 +56,7 @@ logger.info("test_cache_map_basic1 Ended.\n") - +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def
test_cache_map_basic2(): """ Test mappable leaf with the cache op later in the tree above the map(decode) @@ -82,7 +85,7 @@ def test_cache_map_basic2(): logger.info("test_cache_map_basic2 Ended.\n") - +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_basic3(): """ Test a repeat under mappable cache @@ -116,7 +119,7 @@ def test_cache_map_basic3(): assert num_iter == 8 logger.info('test_cache_basic3 Ended.\n') - +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_basic4(): """ Test different rows result in core dump @@ -141,7 +144,7 @@ def test_cache_map_basic4(): assert num_iter == 8 logger.info('test_cache_basic3 Ended.\n') - +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_failure1(): """ Test nested cache (failure) diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index 4a00cc5488..010d32e370 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -15,6 +15,8 @@ """ Testing cache operator with non-mappable datasets """ +import os +import pytest import mindspore.common.dtype as mstype import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision @@ -25,6 +27,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" GENERATE_GOLDEN = False +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_basic1(): """ A random dataset (a non mappable dataset) with a cache over it just after the leaf @@ -54,6 +57,7 @@ def test_cache_nomap_basic1(): logger.info("test_cache_nomap_basic1 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_basic2(): """ A random dataset (a non mappable dataset) with a cache over it just after the leaf @@ -85,6 +89,7 @@ def test_cache_nomap_basic2(): logger.info("test_cache_nomap_basic2 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_basic3(): """ A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf @@ -112,9 +117,21 @@ def test_cache_nomap_basic3(): logger.info("Number of data in ds1: {} ".format(num_iter)) assert num_iter == 12 + + # Contact the server to get the statistics + stat = some_cache.GetStat() + cache_sz = stat.avg_cache_sz + num_mem_cached = stat.num_mem_cached + num_disk_cached = stat.num_disk_cached + + logger.info("Number of rows cached in memory: {}".format(num_mem_cached)) + logger.info("Number of rows spilled to disk: {}".format(num_disk_cached)) + logger.info("Average row cache size: {}".format(cache_sz)) + logger.info("test_cache_nomap_basic3 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_basic4(): """ A TF reader dataset (a non mappable dataset) with a map decode and cache after it @@ -155,6 +172,7 @@ def test_cache_nomap_basic4(): logger.info("test_cache_nomap_basic4 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_basic5(): """ A TF reader dataset (a non mappable dataset) with a cache over it just 
after the leaf @@ -191,6 +209,7 @@ def test_cache_nomap_basic5(): logger.info("test_cache_nomap_basic5 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_basic6(): """ A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf @@ -230,6 +249,7 @@ def test_cache_nomap_basic6(): logger.info("test_cache_nomap_basic6 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_basic7(): """ A TF reader dataset (a non mappable dataset) that uses global shuffle, and is cached followed by @@ -265,6 +285,7 @@ def test_cache_nomap_basic7(): logger.info("test_cache_nomap_basic7 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_allowed_share1(): """ It is allowed to share the cache between the following two trees: @@ -280,7 +301,7 @@ def test_cache_nomap_allowed_share1(): ds.config.set_seed(1) # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True, prefetch_size=32) ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) ds1 = ds1.repeat(4) @@ -300,6 +321,7 @@ def test_cache_nomap_allowed_share1(): logger.info("test_cache_nomap_allowed_share1 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_allowed_share2(): """ It is allowed to share the cache between the following two trees (with map decode): @@ -341,6 +363,7 @@ def test_cache_nomap_allowed_share2(): logger.info("test_cache_nomap_allowed_share2 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_allowed_share3(): """ It is allowed to share the cache between the following two trees (different shard ids): @@ -376,6 +399,7 @@ def test_cache_nomap_allowed_share3(): logger.info("test_cache_nomap_allowed_share3 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_allowed_share4(): """ It is allowed to share the cache between the following two trees: @@ -414,6 +438,7 @@ def test_cache_nomap_allowed_share4(): logger.info("test_cache_nomap_allowed_share4 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_disallowed_share1(): """ It is not allowed to share the cache between the following two trees: