parent
b11ef57b65
commit
8a08d0c37b
@ -1,8 +1,47 @@
|
||||
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
|
||||
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
|
||||
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
|
||||
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
|
||||
add_library(engine-cache-client OBJECT
|
||||
cache_client.cc
|
||||
cache_fbb.cc
|
||||
cache_request.cc)
|
||||
add_library(engine-cache-server OBJECT
|
||||
cache_service.cc
|
||||
cache_server.cc)
|
||||
|
||||
if (ENABLE_CACHE)
|
||||
ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
|
||||
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)
|
||||
|
||||
add_library(engine-cache-server OBJECT
|
||||
${CACHE_GRPC_SRCS}
|
||||
cache_grpc_server.cc
|
||||
cache_arena.cc
|
||||
cache_service.cc
|
||||
cache_server.cc)
|
||||
|
||||
add_executable(cache_server cache_main.cc)
|
||||
target_link_libraries(cache_server
|
||||
engine-cache-server
|
||||
$<TARGET_OBJECTS:utils>
|
||||
mindspore
|
||||
mindspore::glog
|
||||
mindspore::protobuf
|
||||
mindspore::grpc++
|
||||
mindspore_gvar
|
||||
${PYTHON_LIBRARIES}
|
||||
${SECUREC_LIBRARY}
|
||||
pthread)
|
||||
|
||||
add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
|
||||
target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES} mindspore::glog)
|
||||
|
||||
add_dependencies(engine-cache-server generated_engine_files)
|
||||
|
||||
else ()
|
||||
ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PRTO_HDRS cache_grpc.proto)
|
||||
target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
|
||||
endif ()
|
||||
|
||||
add_dependencies(engine-cache-client generated_engine_files)
|
||||
|
@ -0,0 +1,70 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <unistd.h>
|
||||
#include <iostream>
|
||||
#ifdef USE_GLOG
|
||||
#include <glog/logging.h>
|
||||
#endif
|
||||
#include "minddata/dataset/engine/cache/cache_admin_arg.h"
|
||||
|
||||
namespace ds = mindspore::dataset;
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
ds::Status rc;
|
||||
ds::CacheAdminArgHandler args;
|
||||
std::stringstream arg_stream;
|
||||
|
||||
#ifdef USE_GLOG
|
||||
FLAGS_log_dir = "/tmp";
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
#endif
|
||||
|
||||
std::string warningMsg;
|
||||
warningMsg.reserve(512);
|
||||
warningMsg += "WARNING:\n";
|
||||
warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research";
|
||||
warningMsg += " purposes at this time.\n";
|
||||
warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n";
|
||||
|
||||
// A warning message until the code is mature enough.
|
||||
std::cerr << warningMsg << std::endl;
|
||||
|
||||
if (argc == 1) {
|
||||
args.Help();
|
||||
return 0;
|
||||
}
|
||||
|
||||
// ingest all the args into a string stream for parsing
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
arg_stream << " " << std::string(argv[i]);
|
||||
}
|
||||
|
||||
// Parse the args
|
||||
rc = args.ParseArgStream(&arg_stream);
|
||||
if (!rc.IsOk()) {
|
||||
std::cerr << rc.ToString() << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Execute the command
|
||||
rc = args.RunCommand();
|
||||
if (!rc.IsOk()) {
|
||||
std::cerr << rc.ToString() << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,105 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CacheAdminArgHandler {
|
||||
public:
|
||||
static constexpr int32_t kDefaultPort = 50052;
|
||||
static constexpr int32_t kDefaultNumWorkers = 32;
|
||||
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
|
||||
static constexpr int32_t kDefaultLogLevel = 1;
|
||||
static const char kDefaultHost[];
|
||||
static const char kServerBinary[];
|
||||
static const char kDefaultSpillDir[];
|
||||
|
||||
// These are the actual command types to execute
|
||||
enum class CommandId : int16_t {
|
||||
kCmdHelp = 0,
|
||||
kCmdStart = 1,
|
||||
kCmdStop = 2,
|
||||
kCmdGenerateSession = 3,
|
||||
kCmdDestroySession = 4,
|
||||
kCmdUnknown = 32767
|
||||
};
|
||||
|
||||
CacheAdminArgHandler();
|
||||
|
||||
~CacheAdminArgHandler() = default;
|
||||
|
||||
Status ParseArgStream(std::stringstream *arg_stream);
|
||||
|
||||
Status RunCommand();
|
||||
|
||||
void Help();
|
||||
|
||||
private:
|
||||
// These are the supported argument string integer mappings
|
||||
enum class ArgValue : int16_t {
|
||||
kArgUnknown = 0, // Must be at position 0. invalid map lookups in arg_map_ default to value 0
|
||||
kArgStart = 1,
|
||||
kArgStop = 2,
|
||||
kArgHost = 3,
|
||||
kArgPort = 4,
|
||||
kArgHelp = 5,
|
||||
kArgGenerateSession = 6,
|
||||
kArgDestroySession = 7,
|
||||
kArgSpillDir = 8,
|
||||
kArgNumWorkers = 9,
|
||||
kArgSharedMemorySize = 10,
|
||||
kArgLogLevel = 11,
|
||||
kArgNumArgs = 12 // Must be the last position to provide a count
|
||||
};
|
||||
|
||||
Status StartServer();
|
||||
|
||||
Status StopServer();
|
||||
|
||||
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
|
||||
CommandId command_id = CommandId::kCmdUnknown);
|
||||
|
||||
Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
|
||||
CommandId command_id = CommandId::kCmdUnknown);
|
||||
|
||||
Status Validate();
|
||||
|
||||
CommandId command_id_;
|
||||
int32_t port_;
|
||||
int32_t num_workers_;
|
||||
int32_t shm_mem_sz_;
|
||||
int32_t log_level_;
|
||||
session_id_type session_id_;
|
||||
std::string hostname_;
|
||||
std::string spill_dir_;
|
||||
std::string trailing_args_;
|
||||
std::map<std::string, ArgValue> arg_map_;
|
||||
std::map<ArgValue, bool> used_args_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
|
@ -0,0 +1,73 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/cache/cache_arena.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
|
||||
: Arena::Arena(val_in_GB * 1024), port_(port), shmid_(-1) {}
|
||||
|
||||
CachedSharedMemoryArena::~CachedSharedMemoryArena() {
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
|
||||
shmdt(this->ptr_);
|
||||
}
|
||||
this->ptr_ = nullptr;
|
||||
if (shmid_ != -1) {
|
||||
shmctl(shmid_, IPC_RMID, nullptr);
|
||||
// Also remove the path we use to generate ftok.
|
||||
Path p(PortToUnixSocketPath(port_));
|
||||
(void)p.Remove();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
|
||||
size_t val_in_GB) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
|
||||
if (ba == nullptr) {
|
||||
return Status(StatusCode::kOutOfMemory);
|
||||
}
|
||||
// Transfer the ownership of this pointer. Any future error in the processing we will have
|
||||
// the destructor of *out to deal.
|
||||
(*out).reset(ba);
|
||||
// Generate the ftok using a combination of port.
|
||||
int err;
|
||||
auto shm_key = PortToFtok(port, &err);
|
||||
if (shm_key == (key_t)-1) {
|
||||
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
|
||||
ba->shmid_ = shmget(shm_key, ba->size_in_bytes_, IPC_CREAT | IPC_EXCL | access_mode);
|
||||
if (ba->shmid_) {
|
||||
ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
|
||||
if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
uint64_t num_blks = ba->size_in_bytes_ / ARENA_BLK_SZ;
|
||||
MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << ".";
|
||||
ba->tr_.Insert(0, num_blks);
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,52 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "minddata/dataset/util/arena.h"
|
||||
#include "minddata/dataset/engine/cache/cache_common.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// This is a derived class of Arena but resides in shared memory
|
||||
class CachedSharedMemoryArena : public Arena {
|
||||
public:
|
||||
~CachedSharedMemoryArena() override;
|
||||
/// \brief Create an Arena in shared memory
|
||||
/// \param[out] p_ba Pointer to a unique_ptr
|
||||
/// \param shmkey Shared memory key
|
||||
/// \param val_in_GB size of shared memory in gigabyte
|
||||
/// \return Status object
|
||||
static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
|
||||
|
||||
/// \brief This returns where we attach to the shared memory.
|
||||
/// Some gRPC requests will ask for a shared memory block, and
|
||||
/// we can't return the absolute address as this makes no sense
|
||||
/// in the client. So instead we will return an address relative
|
||||
/// to the base address of the shared memory where we attach to.
|
||||
/// \return Base address of the shared memory.
|
||||
const void *SharedMemoryBaseAddr() const { return this->ptr_; }
|
||||
|
||||
private:
|
||||
int32_t port_;
|
||||
int shmid_;
|
||||
/// Private constructor. Not to be called directly.
|
||||
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,90 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
|
||||
|
||||
/// \note This header file contains common header files and some inlines used by
|
||||
/// both client and server side codes. Do not put code that is not common here.
|
||||
/// There are client and server specific header files.
|
||||
|
||||
// On platform like Windows, we may support only tcp/ip clients
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#define CACHE_LOCAL_CLIENT 1
|
||||
#endif
|
||||
|
||||
#ifdef CACHE_LOCAL_CLIENT
|
||||
#include <sys/types.h>
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/shm.h>
|
||||
#else
|
||||
typedef int key_t;
|
||||
#endif
|
||||
#ifdef ENABLE_CACHE
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#endif
|
||||
#include <string>
|
||||
#ifdef ENABLE_CACHE
|
||||
#include "proto/cache_grpc.grpc.pb.h"
|
||||
#endif
|
||||
#include "proto/cache_grpc.pb.h"
|
||||
#include "minddata/dataset/engine/cache/cache_request.h"
|
||||
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief CacheRow and BatchFetch requests will switch to use shared memory method (if supported
|
||||
/// on the platform) when the amount of bytes sent is greater than the following number.
|
||||
/// For too small amount, we won't get any benefit using shared memory method because we need
|
||||
/// two rpc requests to use shared memory method.
|
||||
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
|
||||
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
|
||||
constexpr static uint32_t kLocalClientSupport = 1;
|
||||
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
|
||||
/// inline in the protobuf. This also implies kLocalClientSupport is also true.
|
||||
constexpr static uint32_t kDataIsInSharedMemory = 2;
|
||||
|
||||
/// \brief Convert a Status object into a protobuf
|
||||
/// \param rc[in] Status object
|
||||
/// \param reply[in/out] pointer to pre-allocated protobuf object
|
||||
inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
|
||||
reply->set_rc(static_cast<google::int32>(rc.get_code()));
|
||||
reply->set_msg(rc.ToString());
|
||||
}
|
||||
|
||||
/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
|
||||
/// \param port
|
||||
/// \return unix socket url
|
||||
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
|
||||
|
||||
/// \brief Generate a shared memory key using the tcp/ip port.
|
||||
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
|
||||
/// \note Caller must check the return value. -1 means ftok failed.
|
||||
/// \param[in] port
|
||||
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
|
||||
/// \return key
|
||||
inline key_t PortToFtok(int port, int *err) {
|
||||
key_t shmkey = -1;
|
||||
#ifdef CACHE_LOCAL_CLIENT
|
||||
const std::string unix_path = PortToUnixSocketPath(port);
|
||||
shmkey = ftok(unix_path.data(), 'a');
|
||||
if (err != nullptr && shmkey == (key_t)-1) {
|
||||
*err = errno;
|
||||
}
|
||||
#endif
|
||||
return shmkey;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
|
@ -0,0 +1,151 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/cache/cache_fbb.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// A private function used by SerializeTensorRowHeader to serialize each column in a tensor
|
||||
/// \note Not to be called by outside world
|
||||
/// \return Status object
|
||||
Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
|
||||
const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out_off);
|
||||
const Tensor *ts = ts_ptr.get();
|
||||
auto shape_off = fbb->CreateVector(ts->shape().AsVector());
|
||||
const auto ptr = ts->GetBuffer();
|
||||
if (ptr == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
|
||||
}
|
||||
auto src = ts->type().value();
|
||||
TensorType dest;
|
||||
#define CASE(t) \
|
||||
case DataType::t: \
|
||||
dest = TensorType::TensorType_##t; \
|
||||
break
|
||||
// Map the type to fill in the flat buffer.
|
||||
switch (src) {
|
||||
CASE(DE_BOOL);
|
||||
CASE(DE_INT8);
|
||||
CASE(DE_UINT8);
|
||||
CASE(DE_INT16);
|
||||
CASE(DE_UINT16);
|
||||
CASE(DE_INT32);
|
||||
CASE(DE_UINT32);
|
||||
CASE(DE_INT64);
|
||||
CASE(DE_UINT64);
|
||||
CASE(DE_FLOAT16);
|
||||
CASE(DE_FLOAT32);
|
||||
CASE(DE_FLOAT64);
|
||||
CASE(DE_STRING);
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
|
||||
RETURN_STATUS_UNEXPECTED("Unknown type");
|
||||
}
|
||||
#undef CASE
|
||||
|
||||
TensorMetaMsgBuilder ts_builder(*fbb);
|
||||
ts_builder.add_dims(shape_off);
|
||||
ts_builder.add_type(dest);
|
||||
auto ts_off = ts_builder.Finish();
|
||||
*out_off = ts_off;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out_fbb);
|
||||
auto fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
|
||||
try {
|
||||
fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
|
||||
std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
|
||||
std::vector<int64_t> tensor_sz;
|
||||
v.reserve(row.size());
|
||||
tensor_sz.reserve(row.size());
|
||||
// We will go through each column in the row.
|
||||
for (const std::shared_ptr<Tensor> &ts_ptr : row) {
|
||||
flatbuffers::Offset<TensorMetaMsg> ts_off;
|
||||
RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off));
|
||||
v.push_back(ts_off);
|
||||
tensor_sz.push_back(ts_ptr->SizeInBytes());
|
||||
}
|
||||
auto column_off = fbb->CreateVector(v);
|
||||
auto data_sz_off = fbb->CreateVector(tensor_sz);
|
||||
TensorRowHeaderMsgBuilder row_builder(*fbb);
|
||||
row_builder.add_column(column_off);
|
||||
row_builder.add_data_sz(data_sz_off);
|
||||
// Pass the row_id even if it may not be known.
|
||||
row_builder.add_row_id(row.getId());
|
||||
row_builder.add_size_of_this(-1); // fill in later after we call Finish.
|
||||
auto out = row_builder.Finish();
|
||||
fbb->Finish(out);
|
||||
// Now go back to fill in size_of_this in the flat buffer.
|
||||
auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
|
||||
auto success = msg->mutate_size_of_this(fbb->GetSize());
|
||||
if (!success) {
|
||||
RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
|
||||
}
|
||||
(*out_fbb) = std::move(fbb);
|
||||
return Status::OK();
|
||||
} catch (const std::bad_alloc &e) {
|
||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
|
||||
}
|
||||
}
|
||||
|
||||
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) {
|
||||
RETURN_UNEXPECTED_IF_NULL(col_ts);
|
||||
auto shape_in = col_ts->dims();
|
||||
auto type_in = col_ts->type();
|
||||
std::vector<dsize_t> v;
|
||||
v.reserve(shape_in->size());
|
||||
v.assign(shape_in->begin(), shape_in->end());
|
||||
TensorShape shape(v);
|
||||
DataType::Type dest = DataType::DE_UNKNOWN;
|
||||
#define CASE(t) \
|
||||
case TensorType_##t: \
|
||||
dest = DataType::Type::t; \
|
||||
break
|
||||
|
||||
switch (type_in) {
|
||||
CASE(DE_BOOL);
|
||||
CASE(DE_INT8);
|
||||
CASE(DE_UINT8);
|
||||
CASE(DE_INT16);
|
||||
CASE(DE_UINT16);
|
||||
CASE(DE_INT32);
|
||||
CASE(DE_UINT32);
|
||||
CASE(DE_INT64);
|
||||
CASE(DE_UINT64);
|
||||
CASE(DE_FLOAT16);
|
||||
CASE(DE_FLOAT32);
|
||||
CASE(DE_FLOAT64);
|
||||
CASE(DE_STRING);
|
||||
}
|
||||
#undef CASE
|
||||
|
||||
DataType type(dest);
|
||||
std::shared_ptr<Tensor> ts;
|
||||
RETURN_IF_NOT_OK(
|
||||
Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
|
||||
// Next we restore the real data which can be embedded or stored separately.
|
||||
if (ts->SizeInBytes() != data.GetSize()) {
|
||||
MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
|
||||
<< "Dumping tensor\n"
|
||||
<< *ts << "\n";
|
||||
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
|
||||
}
|
||||
*out = std::move(ts);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
|
||||
|
||||
/// This header contains some serialize and deserialize functions for tensor row using
|
||||
/// Google Flatbuffer
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
|
||||
#include "minddata/dataset/core/tensor_row.h"
|
||||
#include "minddata/dataset/util/slice.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Function to serialize TensorRow header used by CacheRowRequest
|
||||
/// \param row TensorRow
|
||||
/// \param fbb [in/out] fbb that contains the serialized data
|
||||
/// \return Status object
|
||||
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *fbb);
|
||||
|
||||
/// \brief A function used by BatchFetchRequest to deserialize a flat buffer back to a tensor row.
|
||||
/// \param col_ts A serialized version of Tensor meta data
|
||||
/// \param data Tensor data wrapped in a slice
|
||||
/// \param out Tensor
|
||||
/// \return Status object
|
||||
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
|
@ -0,0 +1,54 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
package mindspore.dataset;
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
// The session_id and crc work together to uniquely identify this particular cache and allow
|
||||
// sharing of the cache.
|
||||
message CacheClientInfo {
|
||||
uint32 session_id = 1;
|
||||
uint32 crc = 2;
|
||||
}
|
||||
|
||||
message CacheRequest {
|
||||
// Type of rpc request
|
||||
int32 type = 1;
|
||||
// Extra optional flag used by individual request if needed
|
||||
uint32 flag = 2;
|
||||
oneof connect_info {
|
||||
// The server_connection_id is the actual id we use for operations after the cache is built
|
||||
int64 connection_id = 3;
|
||||
// But some request like CreateCache we have to use the session id and crc to connect to the server.
|
||||
CacheClientInfo connection_info = 4;
|
||||
}
|
||||
// Everything else is just vector of buffers
|
||||
repeated bytes buf_data = 5;
|
||||
}
|
||||
|
||||
message CacheReply {
|
||||
int32 rc = 1;
|
||||
string msg = 2;
|
||||
// Extra optional flag used by individual request if needed
|
||||
uint32 flag = 3;
|
||||
// What the server send back is a plain buffer
|
||||
bytes result = 4;
|
||||
}
|
||||
|
||||
service CacheServerGreeter {
|
||||
rpc CacheServerRequest (CacheRequest) returns (CacheReply) {}
|
||||
}
|
@ -0,0 +1,161 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
|
||||
#include <chrono>
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
|
||||
std::unique_ptr<CacheClientRequestTag> &&tag) {
|
||||
// If there is anything extra we need to do before we send.
|
||||
RETURN_IF_NOT_OK(tag->base_rq_->Prepare());
|
||||
// One minute timeout
|
||||
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
|
||||
tag->ctx_.set_deadline(deadline);
|
||||
tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq);
|
||||
tag->rpc_->StartCall();
|
||||
// Last step is we release the ownership and transfer it to the completion queue.
|
||||
// The memory will be released by WorkerEntry or by the destructor when we drain the queue
|
||||
auto ccReqTag = tag.release();
|
||||
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_,
|
||||
ccReqTag); // inject this object into the completion queue
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
CacheClientGreeter::~CacheClientGreeter() {
|
||||
(void)ServiceStop();
|
||||
// Detach from shared memory if any
|
||||
if (shmat_addr_ != nullptr) {
|
||||
shmdt(shmat_addr_);
|
||||
shmat_addr_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers)
|
||||
: num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) {
|
||||
grpc::ChannelArguments args;
|
||||
// We need to bump up the message size to unlimited. The default receiving
|
||||
// message limit is 4MB which is not big enough.
|
||||
args.SetMaxReceiveMessageSize(-1);
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
// Try connect locally to the unix_socket first as the first preference
|
||||
// Need to resolve hostname to ip address rather than to do a string compare
|
||||
if (hostname == "127.0.0.1") {
|
||||
std::string target = "unix://" + PortToUnixSocketPath(port);
|
||||
channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
|
||||
} else {
|
||||
#endif
|
||||
std::string target = hostname + ":" + std::to_string(port);
|
||||
channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
}
|
||||
#endif
|
||||
stub_ = CacheServerGreeter::NewStub(channel_);
|
||||
}
|
||||
|
||||
Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) {
|
||||
*local_bypass = false;
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
int err;
|
||||
shm_key_ = PortToFtok(port, &err);
|
||||
if (shm_key_ == (key_t)-1) {
|
||||
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
// Attach to the shared memory
|
||||
shm_id_ = shmget(shm_key_, 0, 0);
|
||||
if (shm_id_ == -1) {
|
||||
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
shmat_addr_ = shmat(shm_id_, nullptr, 0);
|
||||
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
*local_bypass = true;
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheClientGreeter::DoServiceStart() {
|
||||
RETURN_IF_NOT_OK(vg_.ServiceStart());
|
||||
RETURN_IF_NOT_OK(DispatchWorkers(num_workers_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheClientGreeter::DoServiceStop() {
|
||||
// Shutdown the queue. We don't accept any more new incomers.
|
||||
cq_.Shutdown();
|
||||
// Shutdown the TaskGroup.
|
||||
vg_.interrupt_all();
|
||||
vg_.join_all(Task::WaitFlag::kNonBlocking);
|
||||
// Drain the queue
|
||||
bool success;
|
||||
void *tag;
|
||||
while (cq_.Next(&tag, &success)) {
|
||||
auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
|
||||
delete r;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
|
||||
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq));
|
||||
return tag->MakeCall(stub_.get(), &cq_, std::move(tag));
|
||||
}
|
||||
|
||||
Status CacheClientGreeter::WorkerEntry() {
|
||||
TaskManager::FindMe()->Post();
|
||||
do {
|
||||
bool success;
|
||||
void *tag;
|
||||
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
|
||||
// Set a timeout for one second. Check for interrupt if we need to do early exit.
|
||||
auto r = cq_.AsyncNext(&tag, &success, deadline);
|
||||
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
|
||||
auto rq = reinterpret_cast<CacheClientRequestTag *>(tag);
|
||||
if (success) {
|
||||
auto &rc = rq->rc_;
|
||||
if (!rc.ok()) {
|
||||
auto error_code = rq->rc_.error_code();
|
||||
std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
|
||||
Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
|
||||
}
|
||||
// Notify the waiting thread.
|
||||
rq->Notify();
|
||||
}
|
||||
// We can now free the memory
|
||||
delete rq;
|
||||
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
|
||||
// If we are interrupted, exit. Otherwise wait again.
|
||||
RETURN_IF_INTERRUPTED();
|
||||
} else {
|
||||
// Queue is drained.
|
||||
break;
|
||||
}
|
||||
} while (true);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheClientGreeter::DispatchWorkers(int32_t num_workers) {
|
||||
auto f = std::bind(&CacheClientGreeter::WorkerEntry, this);
|
||||
for (auto i = 0; i < num_workers; ++i) {
|
||||
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async reply", f));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,102 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/engine/cache/cache_common.h"
|
||||
#include "minddata/dataset/util/service.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief A client view of gRPC request
|
||||
/// Like the class CacheServerRequest, this is used as a tag to inject into the gRPC
|
||||
/// completion queue. The thread that makes the rpc request will wait on a wait post
|
||||
/// area for the reply to come back. Since this tag will be deleted from memory and
|
||||
/// we thus we need to work on a shared pointer of the BaseRequest such that its
|
||||
/// use count is at least two. Otherwise either thread will be referencing stale memory.
|
||||
/// \see CacheServerRequest
|
||||
class CacheClientRequestTag {
|
||||
public:
|
||||
friend class CacheClientGreeter;
|
||||
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {}
|
||||
~CacheClientRequestTag() = default;
|
||||
|
||||
/// \brief Make a RPC call
|
||||
/// \param stub from CacheClientGreeter
|
||||
/// \param cq from CacheClientGreeter
|
||||
/// \return Status object
|
||||
static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
|
||||
std::unique_ptr<CacheClientRequestTag> &&tag);
|
||||
|
||||
/// \brief Notify the client that a result has come back from the server
|
||||
void Notify() { base_rq_->wp_.Set(); }
|
||||
|
||||
private:
|
||||
std::shared_ptr<BaseRequest> base_rq_;
|
||||
grpc::Status rc_;
|
||||
grpc::ClientContext ctx_;
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_;
|
||||
};
|
||||
|
||||
/// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC
|
||||
/// \see BaseRequest
|
||||
class CacheClientGreeter : public Service {
|
||||
friend class CacheClient;
|
||||
|
||||
public:
|
||||
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers);
|
||||
~CacheClientGreeter();
|
||||
|
||||
/// Override base Service class
|
||||
Status DoServiceStart() override;
|
||||
Status DoServiceStop() override;
|
||||
|
||||
/// \brief Send the request to the server
|
||||
/// \return Status object
|
||||
Status HandleRequest(std::shared_ptr<BaseRequest> rq);
|
||||
|
||||
/// \brief A handful of threads will be handling async reply from the server
|
||||
/// \return
|
||||
Status WorkerEntry();
|
||||
|
||||
/// \brief Kick off threads to receive reply from the server
|
||||
Status DispatchWorkers(int32_t num_workers);
|
||||
|
||||
/// \brief Attach to shared memory for local client
|
||||
/// \note Called after we have established a connection.
|
||||
/// \return Status object.
|
||||
Status AttachToSharedMemory(int32_t port, bool *local_bypass);
|
||||
|
||||
/// \brief This returns where we attach to the shared memory.
|
||||
/// \return Base address of the shared memory.
|
||||
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<grpc::Channel> channel_;
|
||||
std::unique_ptr<CacheServerGreeter::Stub> stub_;
|
||||
grpc::CompletionQueue cq_;
|
||||
TaskGroup vg_;
|
||||
int32_t num_workers_;
|
||||
key_t shm_key_;
|
||||
int32_t shm_id_;
|
||||
void *shmat_addr_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
|
@ -0,0 +1,203 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
|
||||
#include <limits>
|
||||
#include "minddata/dataset/engine/cache/cache_server.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "utils/log_adapter.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
|
||||
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) {
|
||||
// Setup a path for unix socket.
|
||||
unix_socket_ = PortToUnixSocketPath(port);
|
||||
// We can't generate the ftok key yet until the unix_socket_ is created
|
||||
}
|
||||
|
||||
void CacheServerGreeterImpl::Shutdown() {
|
||||
if (server_) {
|
||||
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
|
||||
server_->Shutdown(deadline);
|
||||
}
|
||||
// Always shutdown the completion queue after the server.
|
||||
if (cq_) {
|
||||
cq_->Shutdown();
|
||||
// We need to drain the queue. All the tag is coming from
|
||||
// the Services pool which will be shutdown as well. So we
|
||||
// ignore the tag.
|
||||
void *tag;
|
||||
bool success;
|
||||
while (cq_->Next(&tag, &success)) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); }
|
||||
|
||||
Status CacheServerGreeterImpl::IpcResourceCleanup() {
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
int err;
|
||||
auto shm_key = PortToFtok(port_, &err);
|
||||
// We are expecting the unix path doesn't exist.
|
||||
if (shm_key == (key_t)-1) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Attach to the shared memory
|
||||
auto shm_id = shmget(shm_key, 0, 0);
|
||||
if (shm_id == -1) {
|
||||
return Status::OK();
|
||||
}
|
||||
struct shmid_ds ds {};
|
||||
auto inx = shmctl(shm_id, IPC_STAT, &ds);
|
||||
if (inx == -1) {
|
||||
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id);
|
||||
errMsg += "\nPlesae remove it manually using ipcrm -m command";
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
if (ds.shm_nattch == 0) {
|
||||
// Stale shared memory from last time.
|
||||
// Remove both the memory and the socket path
|
||||
inx = shmctl(shm_id, IPC_RMID, nullptr);
|
||||
if (inx == -1) {
|
||||
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id);
|
||||
errMsg += ". Errno :" + std::to_string(errno);
|
||||
errMsg += "\nPlesae remove it manually using ipcrm -m command";
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
Path p(unix_socket_);
|
||||
(void)p.Remove();
|
||||
} else {
|
||||
// Server is already up.
|
||||
MS_LOG(ERROR) << "Cache server is already up and running";
|
||||
// We return a duplicate error. The main() will intercept
|
||||
// and output a proper message
|
||||
return Status(StatusCode::kDuplicateKey);
|
||||
}
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServerGreeterImpl::Run() {
|
||||
// To listen on all interfaces, use 0.0.0.0
|
||||
// Use 127.0.0.1 if just locally on the same machine.
|
||||
std::string host("0.0.0.0"); // listen on all interfaces.
|
||||
std::string server_address = host + ":" + std::to_string(port_);
|
||||
grpc::ServerBuilder builder;
|
||||
// Default message size for gRPC is 4MB. Increase it to 2g-1
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
|
||||
int port_tcpip = 0;
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
int port_local = 0;
|
||||
// Check if we need to do clean up on the shared memory if the server
|
||||
// came down unexpectedly like SEGV
|
||||
RETURN_IF_NOT_OK(IpcResourceCleanup());
|
||||
// We also optimize on local clients on the same machine using unix socket
|
||||
builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local);
|
||||
#endif
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip);
|
||||
builder.RegisterService(&svc_);
|
||||
cq_ = builder.AddCompletionQueue();
|
||||
server_ = builder.BuildAndStart();
|
||||
if (server_) {
|
||||
MS_LOG(INFO) << "Server listening on " << server_address;
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
|
||||
MS_LOG(INFO) << "Creation of local socket and shared memory successful";
|
||||
#endif
|
||||
} else {
|
||||
std::string errMsg = "Fail to start server. ";
|
||||
if (port_tcpip != port_) {
|
||||
errMsg += "Unable to bind to tcpip port " + std::to_string(port_) + ".";
|
||||
}
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
if (port_local == 0) {
|
||||
errMsg += " Unable to create unix socket " + unix_socket_ + ".";
|
||||
}
|
||||
#endif
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServerGreeterImpl::HandleRequest(int32_t worker_id) {
|
||||
bool success;
|
||||
void *tag;
|
||||
// We loop through the grpc queue. Each connection if successful
|
||||
// will come back with our own tag which is an instance of CacheServerRequest
|
||||
// and we simply call its functor. But first we need to create these instances
|
||||
// and inject them into the grpc queue.
|
||||
CacheServerRequest *p;
|
||||
// Get a free tag from my free list.
|
||||
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(worker_id, &p));
|
||||
RETURN_IF_NOT_OK((*p)(&svc_, cq_.get()));
|
||||
do {
|
||||
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
|
||||
// Set a timeout for one second. Check for interrupt if we need to do early exit.
|
||||
auto r = cq_->AsyncNext(&tag, &success, deadline);
|
||||
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
|
||||
if (success) {
|
||||
auto rq = static_cast<CacheServerRequest *>(tag);
|
||||
RETURN_IF_NOT_OK((*rq)(&svc_, cq_.get()));
|
||||
}
|
||||
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
|
||||
// If we are interrupted, exit. Otherwise wait again.
|
||||
RETURN_IF_INTERRUPTED();
|
||||
} else {
|
||||
// Queue is drained.
|
||||
break;
|
||||
}
|
||||
} while (true);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq) {
|
||||
auto myQID = getQid();
|
||||
if (st_ == STATE::CREATE) {
|
||||
st_ = STATE::PROCESS;
|
||||
svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this);
|
||||
} else if (st_ == STATE::PROCESS) {
|
||||
// Get a new tag and handle the next request before we serve the current request.
|
||||
// The tag will be recycled when its state is changed to FINISH
|
||||
CacheServerRequest *next_rq;
|
||||
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq));
|
||||
RETURN_IF_NOT_OK((*next_rq)(svc, cq));
|
||||
// Now we continue with the current request.
|
||||
// First thing we need to extract the type from the incoming request.
|
||||
// When this object was first created (i.e. STATE::CREATE), we set the type to UNKNOWN.
|
||||
type_ = static_cast<RequestType>(rq_.type());
|
||||
// Now we pass the address of this instance to CacheServer's main loop.
|
||||
MS_LOG(DEBUG) << "Handle request " << *this;
|
||||
auto &cs = CacheServer::GetInstance();
|
||||
RETURN_IF_NOT_OK(cs.PushRequest(myQID, this));
|
||||
} else if (st_ == STATE::FINISH) {
|
||||
MS_LOG(DEBUG) << *this << " Finished.";
|
||||
// Return back to the free list.
|
||||
RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(this));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CacheServerRequest::Print(std::ostream &out) const {
|
||||
if (rq_.has_connection_info()) {
|
||||
out << "Session Id: " << rq_.connection_info().session_id() << " CRC: " << rq_.connection_info().crc();
|
||||
} else {
|
||||
out << "Connection Id: " << rq_.connection_id();
|
||||
}
|
||||
out << " ";
|
||||
BaseRequest::Print(out);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,103 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/cache/cache_common.h"
|
||||
#include "minddata/dataset/engine/cache/cache_arena.h"
|
||||
#include "minddata/dataset/util/allocator.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Server side view of BaseRequest. Incoming request are in the form of protobuf objects
|
||||
/// and this class is used to translate from protobuf to structures understood by CacheService class.
|
||||
/// \see CacheService
|
||||
class CacheServerRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
|
||||
explicit CacheServerRequest(int32_t queue_id)
|
||||
: BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),
|
||||
qid_(queue_id),
|
||||
st_(STATE::CREATE),
|
||||
responder_(&ctx_) {}
|
||||
|
||||
~CacheServerRequest() = default;
|
||||
|
||||
/// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this
|
||||
/// functor will translate each protobuf into some form understood by by CacheService class.
|
||||
/// \param svc Async service
|
||||
/// \param cq Completion queue
|
||||
/// \return Status object
|
||||
Status operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq);
|
||||
|
||||
/// \brief Override the base class Print method
|
||||
/// \param out
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Getter of the queue id
|
||||
/// \return The queue where the request should go to
|
||||
int32_t getQid() const { return qid_; }
|
||||
|
||||
private:
|
||||
int32_t qid_;
|
||||
Status rc_;
|
||||
STATE st_;
|
||||
grpc::ServerContext ctx_;
|
||||
grpc::ServerAsyncResponseWriter<CacheReply> responder_;
|
||||
};
|
||||
|
||||
/// \brief Implementation of CacheServerGreeter
|
||||
/// \note It is an async server
|
||||
/// \see cache_grpc.proto
|
||||
class CacheServerGreeterImpl final {
|
||||
friend class CacheServer;
|
||||
|
||||
public:
|
||||
explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
|
||||
virtual ~CacheServerGreeterImpl();
|
||||
/// \brief Brings up gRPC server
|
||||
/// \return none
|
||||
Status Run();
|
||||
/// \brief Entry function to handle cache server request
|
||||
Status HandleRequest(int32_t worker_id);
|
||||
|
||||
/// Return the shared memory pool.
|
||||
/// \return Return the shared memory pool
|
||||
CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }
|
||||
|
||||
void Shutdown();
|
||||
|
||||
Status IpcResourceCleanup();
|
||||
|
||||
private:
|
||||
int32_t port_;
|
||||
size_t shm_pool_sz_in_gb_;
|
||||
std::string unix_socket_;
|
||||
CacheServerGreeter::AsyncService svc_;
|
||||
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
|
@ -0,0 +1,121 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/cache/cache_server.h"
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#ifdef USE_GLOG
|
||||
#include <glog/logging.h>
|
||||
#endif
|
||||
#include <cstdlib>
|
||||
|
||||
namespace ds = mindspore::dataset;
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
ds::Status rc;
|
||||
ds::CacheServer::Builder builder;
|
||||
|
||||
// This executable is not to be called directly, and should be invoked by cache_admin executable.
|
||||
if (argc != 7) {
|
||||
rc = ds::Status(ds::StatusCode::kSyntaxError);
|
||||
std::cerr << rc.ToString() << std::endl;
|
||||
return static_cast<int>(rc.get_code());
|
||||
}
|
||||
|
||||
builder.SetRootDirectory(argv[1])
|
||||
.SetNumWorkers(strtol(argv[2], nullptr, 10))
|
||||
.SetPort(strtol(argv[3], nullptr, 10))
|
||||
.SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10));
|
||||
|
||||
#ifdef USE_GLOG
|
||||
FLAGS_minloglevel = strtol(argv[5], nullptr, 10);
|
||||
#endif
|
||||
|
||||
auto daemonize_string = argv[6];
|
||||
bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 ||
|
||||
strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0;
|
||||
|
||||
// We always change directory to / on unix rather than using the directory where the cache_server
|
||||
// is called. This is a standard procedure for daemonize a process on unix.
|
||||
if (chdir("/") == -1) {
|
||||
std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
|
||||
std::cerr << errMsg << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Simple check of the parameters before we move on.
|
||||
rc = builder.SanityCheck();
|
||||
if (rc.IsError()) {
|
||||
std::cerr << rc.ToString() << std::endl;
|
||||
return static_cast<int>(rc.get_code());
|
||||
}
|
||||
|
||||
#ifdef USE_GLOG
|
||||
FLAGS_log_dir = "/tmp";
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
#endif
|
||||
|
||||
if (daemonize) {
|
||||
// fork the child process to become the daemon
|
||||
pid_t pid = fork();
|
||||
// failed to fork
|
||||
if (pid < 0) {
|
||||
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
|
||||
std::cerr << err_msg << std::endl;
|
||||
return errno;
|
||||
} else if (pid > 0) {
|
||||
// Parent
|
||||
std::cerr << "cache server daemon process has been created as process id: " << pid
|
||||
<< "\nCheck log file for any start up error" << std::endl;
|
||||
signal(SIGCHLD, SIG_IGN); // ignore sig child signal.
|
||||
return 0;
|
||||
} else {
|
||||
// Child process will continue from here if daemonize and parent has already exited.
|
||||
// If we are running in the foreground, none of the code in block below will be run.
|
||||
pid_t sid;
|
||||
umask(0);
|
||||
sid = setsid();
|
||||
if (sid < 0) {
|
||||
MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno);
|
||||
return errno;
|
||||
}
|
||||
close(0);
|
||||
close(1);
|
||||
close(2);
|
||||
}
|
||||
}
|
||||
|
||||
// Dump the summary
|
||||
MS_LOG(INFO) << builder << std::endl;
|
||||
rc = builder.Build();
|
||||
if (rc.IsOk()) {
|
||||
ds::CacheServer &cs = ds::CacheServer::GetInstance();
|
||||
// Kick off the threads. Loop forever and never return unless error.
|
||||
rc = cs.Run();
|
||||
if (rc.get_code() == ds::StatusCode::kDuplicateKey) {
|
||||
std::string errMsg = "Server is already started";
|
||||
MS_LOG(ERROR) << errMsg;
|
||||
std::cerr << errMsg << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << rc.ToString();
|
||||
std::cerr << rc.ToString() << std::endl;
|
||||
return static_cast<int>(rc.get_code());
|
||||
}
|
||||
return 0;
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue