!2891 CacheOp phase 1

Merge pull request !2891 from Jamie/CacheOp_dev
pull/2891/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit eadcb341e1

@ -47,6 +47,8 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/dataset/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR})
################## Include sub-modules ###############################
add_subdirectory(util)
add_subdirectory(core)
@ -55,7 +57,7 @@ add_subdirectory(engine)
add_subdirectory(api)
add_subdirectory(text)
######################################################################
add_dependencies(core utils)
add_dependencies(utils core)
add_dependencies(kernels-image core)
add_dependencies(kernels-data core)
add_dependencies(kernels core)
@ -89,6 +91,8 @@ set(submodules
$<TARGET_OBJECTS:engine-perf>
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine-cache-client>
$<TARGET_OBJECTS:engine-cache-server>
$<TARGET_OBJECTS:engine>
$<TARGET_OBJECTS:text>
$<TARGET_OBJECTS:text-kernels>
@ -106,6 +110,8 @@ else ()
add_library(_c_dataengine SHARED ${submodules})
endif ()
add_dependencies(_c_dataengine generated_engine_files)
set_target_properties(_c_dataengine PROPERTIES
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"

File diff suppressed because it is too large

@ -35,6 +35,8 @@ namespace mindspore {
namespace dataset {
using DsOpPtr = std::shared_ptr<DatasetOp>;
class CacheClient;
// enum for the dataset operator names
enum OpName {
kShuffle,
@ -181,6 +183,16 @@ class DEPipeline {
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
/// \brief Helper function to inject a cache operator on top of the current operation being built.
/// \param[in] cache_client The client to use for caching
/// \param[in] num_workers The number of workers to use in the cache op
/// \param[in] input_op The operator to build the cache on top of
/// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
/// the cache operator
/// \return Status return code
Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *cache_op);
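As a point of reference, here is a minimal sketch of how this helper might be invoked while building a tree; the map_op variable and all argument values are hypothetical, not taken from this diff:

    // Hedged sketch: inject a cache over a hypothetical upstream operator.
    auto client = std::make_shared<CacheClient>(1 /*session_id*/, 0 /*cache_mem_sz: unlimited*/, true /*spill*/);
    std::shared_ptr<DatasetOp> cache_op;
    RETURN_IF_NOT_OK(AddCacheOp(client, 4 /*num_workers*/, map_op, &cache_op));
    // cache_op now heads the two-node subtree (cache over map_op).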
/// \brief Helper function to inject a shuffle operator on top of the current operation being built.
/// \param[in] shuffle_size The size to use in the shuffle buffer
/// \param[in] input_op The operator to build shuffle on top of

@ -35,6 +35,7 @@
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/concatenate_op.h"
@ -768,6 +769,11 @@ void bindInfoObjects(py::module *m) {
.def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
}
void bindCacheClient(py::module *m) {
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def(py::init<uint32_t, uint64_t, bool>());
}
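The binding exposes the three-argument CacheClient constructor (session id, cache memory size, spill flag) to Python. On the C++ side the equivalent construction, with example values only, is:

    // Same constructor the Python binding calls into; values are examples.
    auto cc = std::make_shared<CacheClient>(1 /*session_id*/, 0 /*cache_mem_sz: unlimited*/, true /*spill*/);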
void bindVocabObjects(py::module *m) {
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
.def(py::init<>())
@ -939,6 +945,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
bindSamplerOps(&m);
bindDatasetOps(&m);
bindInfoObjects(&m);
bindCacheClient(&m);
bindVocabObjects(&m);
bindGraphData(&m);
bindDependIcuTokenizerOps(&m);

@ -2,6 +2,7 @@ add_subdirectory(datasetops)
add_subdirectory(opt)
add_subdirectory(gnn)
add_subdirectory(perf)
add_subdirectory(cache)
if (ENABLE_TDTQUE)
add_subdirectory(tdt)
endif ()
@ -17,7 +18,9 @@ add_library(engine OBJECT
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf)
else()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server)
else ()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server)
endif ()

@ -0,0 +1,8 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-cache-client OBJECT
cache_client.cc
cache_request.cc)
add_library(engine-cache-server OBJECT
cache_service.cc
cache_server.cc)

@ -0,0 +1,208 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iomanip>
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/bit.h"
namespace mindspore {
namespace dataset {
// Constructor
CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
: server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
// A print method for displaying cache details
void CacheClient::Print(std::ostream &out) const {
out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_
<< "\n Spilling: " << std::boolalpha << spill_;
}
Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
CacheRowRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
if (row_id_from_server != nullptr) {
*row_id_from_server = rq.GetRowIdAfterCache();
}
return Status::OK();
}
Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
std::unique_ptr<DataBuffer> db_ptr = std::move(in);
auto num_rows = db_ptr->NumRows();
std::vector<TensorRow> all_rows;
if (num_rows > 0) {
all_rows.reserve(num_rows);
// Break down the DataBuffer into TensorRows. We will send the requests asynchronously
// and then do a final wait.
MemGuard<CacheRowRequest> rq_arr;
RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
CacheServer &cs = CacheServer::GetInstance();
for (auto i = 0; i < num_rows; ++i) {
TensorRow row;
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(cs.PushRequest(rq));
// We can't let row go out of scope. Otherwise it will free all the tensor memory.
// So park it in the vector. When the vector goes out of scope at the end of this
// function, the memory will be freed.
all_rows.push_back(std::move(row));
}
// Now we wait for the requests to be done.
for (auto i = 0; i < num_rows; ++i) {
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(rq->Wait());
}
}
return Status::OK();
}
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
BatchFetchRequest rq(server_connection_id_, row_id);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
RETURN_IF_NOT_OK(rq.RestoreRows(out));
return Status::OK();
}
Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
UniqueLock lck(&mux_);
// To create a cache, we identify ourselves at the client by:
// - the shared session id
// - a crc for the tree nodes from the cache downward
// Pack these 2 into a single 64 bit request id
//
// Consider this example:
// tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
// tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch
// These are different trees in a single session, but the user wants to share the cache.
// This is not allowed because the data of these caches are different.
//
// Consider this example:
// tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
// tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch
// These are different trees in the same session, but the cached data is the same, so it is okay
// to allow the sharing of this cache between these pipelines.
// The CRC is computed by the tree prepare phase and passed to this function when creating the cache.
// If we already have a server_connection_id_, then it means this same cache client has already been used
// to create a cache and some other tree is trying to use the same cache.
// That is allowed, however the crc better match!
if (server_connection_id_) {
if (cache_crc_ != tree_crc) {
RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
}
// Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
// skip the build phase.
lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock.
CacheClient::ServiceStat stat{};
RETURN_IF_NOT_OK(GetStat(&stat));
if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
}
} else {
cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client
// Combine the session and crc. This will form our client cache identifier.
connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
// Now execute the cache create request using this identifier and other configs
BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
if (spill_) {
createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
}
if (generate_id) {
createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
}
CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
Status rc = rq.Wait();
if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
server_connection_id_ = rq.GetServerConnectionId();
if (rc.IsOk()) {
// The first client to create the cache will get a cookie back.
// But this object may be shared among pipelines and we don't want
// to overwrite it.
cookie_ = rq.cookie();
}
}
// We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
// CacheOp to bypass the build phase.
return rc;
}
return Status::OK();
}
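For concreteness, the identifier packing above works out as follows for the session id and crc used in the example trees (a worked calculation, not code from the PR):

    // session_id = 1, tree_crc = 123
    // connection_id = (static_cast<uint64_t>(1) << 32) | 123
    //               = 4294967296 + 123 = 4294967419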
Status CacheClient::PurgeCache() {
UniqueLock lck(&mux_);
PurgeCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
}
Status CacheClient::DestroyCache() {
UniqueLock lck(&mux_);
DestroyCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
}
Status CacheClient::GetStat(ServiceStat *stat) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(stat);
GetStatRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
stat->num_disk_cached = rq.GetNumDiskCached();
stat->num_mem_cached = rq.GetNumMemCached();
stat->min_row_id = rq.GetMinRowId();
stat->max_row_id = rq.GetMaxRowId();
stat->cache_service_state = rq.GetState();
return Status::OK();
}
Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
SharedLock lck(&mux_);
CacheSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
return Status::OK();
}
Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(map);
FetchSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
*map = rq.GetColumnMap();
return Status::OK();
}
Status CacheClient::BuildPhaseDone() const {
SharedLock lck(&mux_);
BuildPhaseDoneRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
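Taken together, the calls in this file imply the following client-side flow. This is a hedged usage sketch (row contents, crc, and ids are placeholders), not code from the PR:

    CacheClient cc(1 /*session_id*/, 0 /*cache_mem_sz: unlimited*/, true /*spill*/);
    RETURN_IF_NOT_OK(cc.CreateCache(123 /*tree_crc, placeholder*/, true /*generate_id*/));
    TensorRow row;                             // assume populated elsewhere
    row_id_type id = -1;
    RETURN_IF_NOT_OK(cc.WriteRow(row, &id));   // server assigns the row id
    TensorTable out;
    RETURN_IF_NOT_OK(cc.GetRows({id}, &out));  // misses come back as empty rows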

@ -0,0 +1,141 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_CLIENT_H_
#define DATASET_ENGINE_CACHE_CLIENT_H_
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/cache/cache_server.h"
#include "dataset/util/lock.h"
namespace mindspore {
namespace dataset {
/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through
/// a CacheClient. Typical tasks include creating a cache service, caching a data buffer, restoring previously
/// cached rows, etc.
class CacheClient {
public:
/// \brief Constructor
/// \param session_id A user assigned session id for the current pipeline
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
/// \brief Destructor
~CacheClient() = default;
/// \brief Getter function for returning the current session id
/// \return session id
uint64_t session_id() const { return session_id_; }
/// \brief Send a TensorRow to the cache server
/// \param[in] row The TensorRow to send to the server
/// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
/// \return return code
Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;
/// \brief Send a DataBuffer to the cache server
/// \param in Unique pointer of the DataBuffer to be cached
/// \return return code
Status WriteBuffer(std::unique_ptr<DataBuffer> &&in) const;
/// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
/// any cache miss
/// \param row_id A vector of row id's
/// \param out A TensorTable of TensorRows.
/// \return return code
Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;
/// \brief Create a cache.
/// \param tree_crc A crc that was generated during tree prepare phase
/// \param generate_id Let the cache service generate row id
/// \return Status object
Status CreateCache(uint32_t tree_crc, bool generate_id);
/// \brief Purge a cache. Cache can be reused after reset.
/// \return Status object
Status PurgeCache();
/// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
/// \return Status object
Status DestroyCache();
/// \brief A structure holding the statistics returned by a GetStat request.
struct ServiceStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
};
/// \brief Get the statistics from a cache.
/// \param[in/out] stat Pointer to a pre-allocated ServiceStat object
/// \return Status object
Status GetStat(ServiceStat *stat);
/// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema
/// \return Status object
Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);
/// \brief Fetch the schema from the cache server
/// \param map Pointer to pre-allocated map object
/// \return Status object.
Status FetchSchema(std::unordered_map<std::string, int32_t> *map);
/// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache
/// client that holds the cookie is allowed to make this request
/// \return Status object
Status BuildPhaseDone() const;
/// \brief A print method typically used for debugging
/// \param out The output stream to write output to
void Print(std::ostream &out) const;
/// \brief Stream output operator overload
/// \return the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) {
cc.Print(out);
return out;
}
/// \brief Every cache service has a cookie which uniquely identifies the CacheClient that created it.
/// \return Cookie
std::string cookie() const { return cookie_; }
private:
mutable RWLock mux_;
uint64_t cache_mem_sz_;
bool spill_;
// The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
// sharing of the cache.
uint32_t session_id_;
uint32_t cache_crc_;
// The server_connection_id_ is the actual id we use for operations after the cache is built
connection_id_type server_connection_id_;
// Some magic cookie returned from the cache server.
std::string cookie_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_CLIENT_H_

@ -0,0 +1,223 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/cache/cache_request.h"
namespace mindspore {
namespace dataset {
Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
buffers_.reserve(row.size() + 1);
RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
buffers_.push_back(fbb_->GetBufferPointer());
for (const auto &ts : row) {
buffers_.push_back(ts->GetBuffer());
}
return Status::OK();
}
Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
std::vector<int64_t> tensor_sz;
v.reserve(row.size());
tensor_sz.reserve(row.size());
// We will go through each column in the row.
for (const std::shared_ptr<Tensor> &ts_ptr : row) {
flatbuffers::Offset<TensorMetaMsg> ts_off;
RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off));
v.push_back(ts_off);
tensor_sz.push_back(ts_ptr->SizeInBytes());
}
auto column_off = fbb_->CreateVector(v);
auto data_sz_off = fbb_->CreateVector(tensor_sz);
TensorRowHeaderMsgBuilder row_builder(*fbb_);
row_builder.add_column(column_off);
row_builder.add_data_sz(data_sz_off);
// Pass the row_id even if it may not be known.
row_builder.add_row_id(row.getId());
row_builder.add_size_of_this(-1); // fill in later after we call Finish.
auto out = row_builder.Finish();
fbb_->Finish(out);
// Now go back to fill in size_of_this in the flat buffer.
auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
auto success = msg->mutate_size_of_this(fbb_->GetSize());
if (!success) {
RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
}
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
flatbuffers::Offset<TensorMetaMsg> *out_off) {
RETURN_UNEXPECTED_IF_NULL(out_off);
const Tensor *ts = ts_ptr.get();
auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
const auto ptr = ts->GetBuffer();
if (ptr == nullptr) {
RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
}
auto src = ts->type().value();
TensorType dest;
#define CASE(t) \
case DataType::t: \
dest = TensorType::TensorType_##t; \
break
// Map the type to fill in the flat buffer.
switch (src) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
default:
MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
RETURN_STATUS_UNEXPECTED("Unknown type");
}
#undef CASE
TensorMetaMsgBuilder ts_builder(*fbb_);
ts_builder.add_dims(shape_off);
ts_builder.add_type(dest);
auto ts_off = ts_builder.Finish();
*out_off = ts_off;
return Status::OK();
}
Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data,
std::shared_ptr<Tensor> *out) {
RETURN_UNEXPECTED_IF_NULL(col_ts);
auto shape_in = col_ts->dims();
auto type_in = col_ts->type();
std::vector<dsize_t> v;
v.reserve(shape_in->size());
v.assign(shape_in->begin(), shape_in->end());
TensorShape shape(v);
DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t) \
case TensorType_##t: \
dest = DataType::Type::t; \
break
switch (type_in) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
}
#undef CASE
DataType type(dest);
std::shared_ptr<Tensor> ts =
std::make_shared<Tensor>(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize());
// Next we restore the real data which can be embedded or stored separately.
if (ts->SizeInBytes() != data.GetSize()) {
MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
<< "Dumping tensor\n"
<< *ts << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
*out = std::move(ts);
return Status::OK();
}
Status BatchFetchRequest::RestoreRows(TensorTable *out) {
RETURN_UNEXPECTED_IF_NULL(out);
auto num_elements = row_id_.size();
auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer());
TensorTable tbl;
tbl.reserve(num_elements);
ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes());
for (auto i = 0; i < num_elements; ++i) {
auto len = offset_array[i + 1] - offset_array[i];
TensorRow row;
row.setId(row_id_.at(i));
if (len > 0) {
ReadableSlice row_data(all, offset_array[i], len);
// Next we de-serialize flat buffer to get back each column
auto msg = GetTensorRowHeaderMsg(row_data.GetPointer());
auto msg_sz = msg->size_of_this();
// Start of the tensor data
auto ts_offset = msg_sz;
row.reserve(msg->column()->size());
for (auto k = 0; k < msg->column()->size(); ++k) {
auto col_ts = msg->column()->Get(k);
std::shared_ptr<Tensor> ts;
ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts));
row.push_back(ts);
ts_offset += data.GetSize();
}
}
tbl.push_back(std::move(row));
}
*out = std::move(tbl);
return Status::OK();
}
Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
v.reserve(map.size());
for (auto &column : map) {
auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second);
v.push_back(c);
}
auto v_off = fbb_->CreateVector(v);
auto final_off = CreateSchemaMsg(*fbb_, v_off);
fbb_->Finish(final_off);
buf_ = fbb_->GetBufferPointer();
len_of_buf_ = fbb_->GetSize();
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() {
if (column_name_id_map_.empty()) {
auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer());
auto v = map_msg->column();
for (auto i = 0; i < v->size(); ++i) {
auto col = map_msg->column()->Get(i);
column_name_id_map_.emplace(col->name()->str(), col->id());
}
}
return column_name_id_map_;
}
} // namespace dataset
} // namespace mindspore
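As a reading aid, the wire layout implied by SerializeCacheRowRequest and RestoreRows above can be summarized as a descriptive sketch:

    // Cache-row request: buffers_ holds n + 1 pieces for a row of n tensors.
    //   buffers_[0]    -> TensorRowHeaderMsg flatbuffer (row_id, per-column meta, size_of_this, data_sz[])
    //   buffers_[1..n] -> raw tensor data, in column order, data_sz[k] bytes for column k
    //
    // Batch-fetch reply: an int64 offset array indexes the rows in one contiguous buffer.
    //   Row i occupies [offset_array[i], offset_array[i+1]); a zero-length span is a
    //   cache miss and is restored as an empty TensorRow.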

@ -0,0 +1,225 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_REQ_H_
#define DATASET_ENGINE_CACHE_REQ_H_
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/core/tensor_row.h"
#include "dataset/util/slice.h"
#include "dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
/// \brief CacheClient communicates with CacheServer using Requests.
class BaseRequest {
public:
// Request types
enum class RequestType : int16_t {
kCacheRow = 0,
kBatchFetchRows = 1,
kCreateCache = 2,
kPurgeCache = 3,
kDestroyCache = 4,
kGetStat = 5,
kCacheSchema = 6,
kFetchSchema = 7,
kBuildPhaseDone = 8,
// Add new request types before this one.
kRequestUnknown = 32767
};
// For kCreateCache
enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };
friend class CacheServer;
/// \brief Base class of a cache server request
/// \param connection_id A combination of session id and crc that uniquely identifies a connection.
/// \param type Type of the request
explicit BaseRequest(connection_id_type connection_id, RequestType type)
: type_(type), connection_id_(connection_id) {}
virtual ~BaseRequest() = default;
/// \brief Wait for the completion of a request
/// \return Status returned from the cache server
Status Wait() {
RETURN_IF_NOT_OK(wp_.Wait());
return rc_;
}
/// \brief Getter function of the current connection id
/// \return Connection id
connection_id_type GetServerConnectionId() const { return connection_id_; }
private:
RequestType type_;
connection_id_type connection_id_;
Status rc_;
WaitPost wp_;
};
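Every concrete request below follows the same round trip: construct it, push it to the server, then block in Wait() for the result. A minimal sketch, mirroring CacheClient::GetStat in cache_client.cc (connection_id assumed known):

    GetStatRequest rq(connection_id);
    RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));  // hand off to a server worker
    RETURN_IF_NOT_OK(rq.Wait());                                    // blocks on wp_, returns rc_ from the server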
/// \brief Request to cache a single TensorRow
class CacheRowRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {}
~CacheRowRequest() = default;
/// \brief Serialize a TensorRow for streaming to the cache server
/// \param row TensorRow
/// \return Status object
Status SerializeCacheRowRequest(const TensorRow &row);
/// \brief Return the row id assigned to this row for non-mappable dataset
/// \return row id of the cached row
row_id_type GetRowIdAfterCache() { return row_id_from_server_; }
private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
row_id_type row_id_from_server_;
std::vector<const void *> buffers_;
std::string cookie_;
/// \brief Private function to serialize one TensorRow
/// \param row TensorRow
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row);
/// \brief Private function to serialize one Tensor
/// \param ts_ptr Tensor
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off);
};
/// \brief Request to fetch rows in batch
class BatchFetchRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id)
: BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {}
Status RestoreRows(TensorTable *out);
private:
std::vector<row_id_type> row_id_;
MemGuard<uint8_t> mem_;
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
};
/// \brief Request to create a cache for the current connection
class CreationCacheRequest : public BaseRequest {
public:
friend class CacheServer;
/// \brief Constructor
/// \param connection_id
/// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
/// \param flag Attributes of the cache.
explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone)
: BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {}
std::string cookie() const { return cookie_; }
private:
uint64_t cache_mem_sz;
CreateCacheFlag flag_;
std::string cookie_;
};
/// \brief Request to purge a cache.
class PurgeCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {}
};
/// \brief Request to destroy a cache
class DestroyCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit DestroyCacheRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kDestroyCache) {}
};
/// \brief Obtain the statistics of the current connection
class GetStatRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {}
row_id_type GetMinRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->min_row_id();
}
row_id_type GetMaxRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->max_row_id();
}
int64_t GetNumMemCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_mem_cached();
}
int64_t GetNumDiskCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_disk_cached();
}
uint8_t GetState() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->state();
}
private:
MemGuard<uint8_t> mem_;
};
/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {}
~CacheSchemaRequest() = default;
Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
const void *GetBuffer() const { return buf_; }
private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
const void *buf_;
int64_t len_of_buf_;
};
/// \brief Request to fetch a schema
class FetchSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FetchSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kFetchSchema) {}
~FetchSchemaRequest() = default;
std::unordered_map<std::string, int32_t> GetColumnMap();
private:
MemGuard<uint8_t> mem_;
std::unordered_map<std::string, int32_t> column_name_id_map_;
};
/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
class BuildPhaseDoneRequest : public BaseRequest {
public:
friend class CacheServer;
BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {}
private:
std::string cookie_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_REQ_H_

File diff suppressed because it is too large

@ -0,0 +1,98 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_SERVER_H_
#define DATASET_ENGINE_CACHE_SERVER_H_
#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include "dataset/engine/cache/cache_service.h"
#include "dataset/core/tensor.h"
#include "dataset/util/arena.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/lock.h"
#include "dataset/util/service.h"
#include "dataset/util/services.h"
#include "dataset/util/system_pool.h"
#include "dataset/util/queue.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
class BaseRequest;
/// \brief A server which provides CacheService services.
class CacheServer : public Service {
public:
friend class Services;
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
CacheServer(const CacheServer &) = delete;
CacheServer &operator=(const CacheServer &) = delete;
CacheServer(CacheServer &&) = delete;
CacheServer &operator=(CacheServer &&) = delete;
static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
Status DoServiceStart() override;
Status DoServiceStop() override;
~CacheServer() { (void)ServiceStop(); }
/// \brief For the current demonstration, a cache client contacts the cache server using a Queue.
/// \param rq The request to push to the server
/// \return Status object
Status PushRequest(BaseRequest *rq) {
RETURN_UNEXPECTED_IF_NULL(rq);
RETURN_IF_NOT_OK(cache_q_->Add(rq));
return Status::OK();
}
private:
mutable RWLock rwLock_;
std::string top_;
cache_index all_caches_;
std::shared_ptr<Queue<BaseRequest *>> cache_q_;
TaskGroup vg_;
int32_t num_workers_;
/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
/// \param num_workers Number of threads for handling requests.
explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3);
/// \brief Locate a cache service from connection id.
/// \return Pointer to cache service. Null if not found
CacheService *GetService(connection_id_type id) const;
/// \brief Create a cache service. We allow multiple clients to create the same cache service.
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
/// a special unique cookie.
/// \param[in] connection_id This is from a Cache client.
/// \param[in] cache_mem_sz
/// \param[in] flag
/// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator
/// \return Status object
Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag,
std::string *out_cookie);
/// \brief Entry point for all server threads.
Status ServerRequest();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_SERVER_H_
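cache_server.cc is suppressed above, so the worker loop is not shown; given the BaseRequest friendship and the members here, each of the num_workers_ threads presumably does something like the following. This is a speculative sketch, not the suppressed code:

    // Speculative worker-loop body (PopFront/Set mirror Queue and WaitPost usage elsewhere in this PR).
    BaseRequest *rq = nullptr;
    RETURN_IF_NOT_OK(cache_q_->PopFront(&rq));  // block for the next client request
    rq->rc_ = Status::OK();                     // in reality: dispatch on rq->type_ to the owning CacheService
    rq->wp_.Set();                              // wake the client blocked in BaseRequest::Wait()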

File diff suppressed because it is too large

@ -0,0 +1,143 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_SERVICE_H_
#define DATASET_ENGINE_CACHE_SERVICE_H_
#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/arena.h"
#include "dataset/util/btree.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/service.h"
#include "dataset/util/services.h"
#include "dataset/util/system_pool.h"
namespace mindspore {
namespace dataset {
struct CacheStat;
/// \brief A cache service for storing/fetching buffers in an in-memory cache, which may spill to disk if the cache
/// service is created to support spilling
class CacheService : public Service {
public:
friend class CacheServer;
using row_map = BPlusTree<row_id_type, CachePool::key_type>;
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase };
/// \brief Constructor
/// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
/// \param root Spill path. Empty string means no spilling
/// \param generate_id If the cache service should generate row ids for buffers that are cached.
/// For non-mappable dataset, this should be set to true.
CacheService(uint64_t mem_sz, const std::string &root, bool generate_id);
~CacheService();
/// \brief For fixed size memory, we will create an Arena.
/// \return false if unlimited memory.
bool UseArena();
Status DoServiceStart() override;
Status DoServiceStop() override;
/// \brief Main function to cache a row which comes in the form of a series of buffers.
/// The first buffer is a Google flatbuffer which describes the rest of the buffers that follow.
/// \param[in] buf Vector of buffer
/// \param[out] row_id_generated The row id assigned to this row if any
/// \return Status object
Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row id.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const;
/// \brief Getter function
/// \return Spilling path
Path GetSpillPath() const;
/// \brief A structure returned from the cache server for statistics request.
class ServiceStat {
public:
using state_type = std::underlying_type<State>::type;
ServiceStat() : min_(0), max_(0), state_(0) {}
CachePool::CacheStat stat_{};
row_id_type min_;
row_id_type max_;
state_type state_;
};
/// \brief Statistics for the current service
/// \param[in/out] stat A pointer to a pre-allocated ServiceStat structure
/// \return Status object
Status GetStat(ServiceStat *stat);
/// \brief Cache schema
/// \param buf A Google Flatbuffer that contains the schema
/// \param len size of the buffer
/// \return Status object
Status CacheSchema(const void *buf, int64_t len);
/// \brief Fetch schema
/// \param out A contiguous memory that contains the serialized form of schema.
/// \return Status object
Status FetchSchema(MemGuard<uint8_t> *out) const;
/// \brief Purge the content of a cache
/// \return Status object
Status Purge();
/// \brief Overload the << operator to print a cache service
/// \param out std::ostream
/// \param cs A cache service
/// \return std::ostream
friend std::ostream &operator<<(std::ostream &out, const CacheService &cs);
/// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient
/// is the creator
/// \return Cookie
std::string cookie() const { return cookie_; }
/// \brief If this cache service generates row ids for cached buffers, it is divided into two phases, a build phase and
/// a read phase.
/// \return True if it has two phases.
bool HasBuildPhase() const { return generate_id_; }
/// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
/// \return Status object
Status BuildPhaseDone();
private:
mutable RWLock rw_lock_;
std::string root_;
uint64_t cache_mem_sz_;
std::shared_ptr<CachePool> cp_;
std::shared_ptr<row_map> map_;
std::atomic<row_id_type> next_id_;
bool generate_id_;
std::atomic<CachePool::key_type> schema_key_;
std::string cookie_;
State st_;
/// \brief Private function to generate a row id
/// \return Row id assigned.
row_id_type GetNextRowId() { return next_id_.fetch_add(1); }
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_SERVICE_H_
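Given the State enum and the generate_id flag above, the service lifecycle can be summarized as follows; the exact transition points live in the suppressed cache_service.cc, so treat this as an inferred sketch:

    // generate_id == true  : kBuildPhase (rows being written, ids generated)
    //                        --BuildPhaseDone() by the cookie holder--> kFetchPhase (rows read back)
    // generate_id == false : no phase split; the service serves reads right away.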

@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace mindspore.dataset;
/// Type of a Tensor
enum TensorType : byte {
DE_UNKNOWN = 0,
DE_BOOL = 1,
DE_INT8 = 2,
DE_UINT8 = 3,
DE_INT16 = 4,
DE_UINT16 = 5,
DE_INT32 = 6,
DE_UINT32 = 7,
DE_INT64 = 8,
DE_UINT64 = 9,
DE_FLOAT16 = 10,
DE_FLOAT32 = 11,
DE_FLOAT64 = 12,
DE_STRING = 13
}
/// The meta information of a Tensor
/// \note Only the type and shape are considered meta information. Tensor data is excluded.
table TensorMetaMsg {
dims:[int64] (required);
type:TensorType;
}
/// This is the first buffer that is sent to a Cache server when a TensorRow is serialized.
/// \param row_id is the row id of the TensorRow.
/// \param column The meta information of each Tensor in the row
/// \param size of this serialized buffer
/// \param size of each tensor data buffer that follows
table TensorRowHeaderMsg {
row_id:int64;
column:[TensorMetaMsg] (required);
size_of_this:int64;
data_sz:[int64] (required);
}
root_type TensorRowHeaderMsg;
/// A row of row id's
table TensorRowIds {
row_id:[int64] (required);
}
/// Statistics returned from each cache service
/// \note It must match CacheService::ServiceStat
table ServiceStatMsg {
num_mem_cached:int64;
num_disk_cached:int64;
min_row_id:int64;
max_row_id:int64;
state:int8;
}
/// Column description of each column in a schema
table ColumnNameMsg {
name:string;
id:int32;
}
/// Serialized form of a schema
table SchemaMsg {
column:[ColumnNameMsg];
}
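On the C++ side these tables are read through the flatbuffers-generated accessors. For example, the getters of GetStatRequest in cache_request.h decode a ServiceStatMsg reply essentially like this:

    auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
    int64_t in_mem  = msg->num_mem_cached();
    int64_t on_disk = msg->num_disk_cached();
    int8_t  state   = msg->state();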

@ -24,10 +24,8 @@ namespace dataset {
// Description: This is the main constructor that is used for making a buffer
DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
void DataBuffer::Print(std::ostream &out, // In: The output stream to print to
bool show_all) const { // In: T/F if it should show everything
// A method for debug printing of the buffer
void DataBuffer::Print(std::ostream &out, bool show_all) const {
out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n";
// If the column counts are set then it means that data has been set into
@ -46,11 +44,6 @@ void DataBuffer::Print(std::ostream &out, // In: The output stream to print
}
}
Status DataBuffer::Load() {
std::string err_msg = "Base class load called, but it does not have an implementation!";
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Remove me!! Callers should fetch rows via pop
Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const {
if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) {
@ -92,8 +85,5 @@ Status DataBuffer::SliceOff(int64_t number_of_rows) {
return Status::OK();
}
// Destructor
DataBuffer::~DataBuffer() {}
} // namespace dataset
} // namespace mindspore

@ -29,11 +29,9 @@
namespace mindspore {
namespace dataset {
// The DataBuffer class is a base class that will represent the data for n values based
// on a unique row id for each row of data.
// There can be different types of DataBuffers to abstract over how the data is stored
// in memory and acquired from storage.
// Each buffer holds a range of consecutive row id's.
/// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between
/// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format
/// where n TensorRows may consist of m tensors (columns).
class DataBuffer {
public:
// Buffer flags
@ -47,13 +45,13 @@ class DataBuffer {
// Description: This is the main constructor that is used for making a buffer
DataBuffer(int32_t id, BufferFlags flags);
// Destructor
virtual ~DataBuffer();
/// \brief default destructor
~DataBuffer() = default;
// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
virtual void Print(std::ostream &out, // In: The output stream to print to
bool show_all) const; // In: T/F if it should show everything
/// \brief A method for debug printing of the buffer
/// \param[inout] out The stream to write to
/// \param[in] show_all A boolean to toggle between details and summary printing
void Print(std::ostream &out, bool show_all) const;
// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
@ -61,10 +59,6 @@ class DataBuffer {
return out;
}
// Name: load()
// Description: populates the DataBuffer with data based on it's id
virtual Status Load();
// Convenience getter functions for flag checking
bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }

@ -17,7 +17,11 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
take_op.cc
shuffle_op.cc
zip_op.cc
concat_op.cc
concat_op.cc
cache_base_op.cc
cache_lookup_op.cc
cache_op.cc
cache_merge_op.cc
)
if (ENABLE_PYTHON)

@ -0,0 +1,185 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_base_op.h"
#include <iomanip>
#include <iostream>
#include "dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
// A print method typically used for debugging
void CacheBase::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCache client:\n" << *cache_client_ << "\n\n";
}
}
// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from its previous execution and then initializes itself so that it can be executed
// again.
Status CacheBase::Reset() {
if (sampler_ != nullptr) {
RETURN_IF_NOT_OK(sampler_->ResetSampler());
}
// Wake up the workers to get them going again in a new epoch
MS_LOG(DEBUG) << Name() << " resetting.";
epoch_sync_.Set();
return Status::OK();
}
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, op_connector_size, sampler),
cache_client_(cache_client),
rows_per_buffer_(rows_per_buf),
// We can cause deadlock if this internal Connector size is too small.
keys_miss_(num_workers_, 1, 1024) {
io_block_queues_.Init(num_workers, op_connector_size);
}
// Common function to fetch samples from the sampler and send them using the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
int64_t buf_cnt = 0;
int64_t wait_cnt = 0;
do {
epoch_sync_.Clear();
std::vector<row_id_type> keys;
int64_t row_cnt = 0;
keys.reserve(rows_per_buffer_);
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
while (!sampler_buffer->eoe()) {
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
++row_cnt;
if (row_cnt % rows_per_buffer_ == 0) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
keys.clear();
}
}
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
if (!keys.empty()) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
}
// send the eoe
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
// If repeating but not the last repeat, wait for reset.
if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
RETURN_IF_NOT_OK(epoch_sync_.Wait());
} else {
// We can break out from the loop.
break;
}
} while (true);
// Flow the eof before exit
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
// Ask all the workers to quit.
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
return Status::OK();
}
Status CacheBase::FetchFromCache(int32_t worker_id) {
int64_t buffer_id = worker_id;
std::unique_ptr<IOBlock> blk;
do {
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
if (blk->eof()) {
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
} else if (blk->eoe()) {
if (AllowCacheMiss()) {
// This code path is for CacheLookupOp acting as a sampler. If we get an eoe from
// the sampler, send an eoe to the physical leaf op as well.
std::vector<row_id_type> eoe;
eoe.push_back(eoe_row_id);
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe));
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
} else {
std::vector<int64_t> keys;
RETURN_IF_NOT_OK(blk->GetKeys(&keys));
if (keys.empty()) {
// empty key is a quit signal for workers
break;
}
std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
TensorTable ttbl;
RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl));
auto row_it = ttbl.begin();
std::vector<row_id_type> cache_miss;
cache_miss.reserve(keys.size());
for (auto row_id : keys) {
auto &row = *row_it;
if (row.empty()) {
if (AllowCacheMiss()) {
cache_miss.push_back(row_id);
} else {
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
que->push_back(std::move(row));
++row_it;
}
db->set_tensor_table(std::move(que));
if (AllowCacheMiss()) {
// Because of the way the connector works, we push unconditionally even if cache_miss is empty.
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss));
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
buffer_id += num_workers_;
}
} while (true);
return Status::OK();
}
Status CacheBase::RegisterResources() {
RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
return Status::OK();
}
CacheBase::~CacheBase() {}
Status CacheBase::UpdateColumnMapFromCache() {
Status rc;
// Get the schema from the server. It may not be there yet. So tolerate the error.
if (column_name_id_map_.empty()) {
rc = cache_client_->FetchSchema(&column_name_id_map_);
if (rc == Status(StatusCode::kFileNotExist)) {
MS_LOG(DEBUG) << "Schema not in the server yet.";
rc = Status::OK();
}
}
return rc;
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,108 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/cache/cache_service.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/sampler/sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/util/queue.h"
#include "dataset/util/wait_post.h"
#include "dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
/// \see CacheOp
/// \see CacheLookupOp
class CacheBase : public ParallelOp {
public:
/// \brief Base class constructor
/// \param num_workers Number of parallel workers
/// \param op_connector_size Connector size
/// \param rows_per_buf Number of rows per buffer
/// \param cache_client CacheClient for communication to the CacheServer
/// \param sampler Sampler which is mandatory
CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
/// \brief Destructor
~CacheBase();
constexpr static int eoe_row_id = -1;
/// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state
/// info from its previous execution and then initializes itself so that it can be executed
/// again.
/// \return Status - The error code return
Status Reset() override;
/// \brief A print method typically used for debugging
/// \param out The output stream to write output to
/// \param show_all A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
/// \brief << Stream output operator overload
/// \notes This allows you to write the debug print info using stream operators
/// \param out reference to the output stream being overloaded
/// \param mo reference to the CacheOp to display
/// \return the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) {
mo.Print(out, false);
return out;
}
/// \brief Getter for the cache client
/// \return shared ptr to the cache client
std::shared_ptr<CacheClient> cache_client() { return cache_client_; }
/// \brief Setter for the cache client
void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
/// \brief Derived class must implement this method to indicate whether a cache miss is tolerated or treated as an error
virtual bool AllowCacheMiss() = 0;
protected:
std::shared_ptr<CacheClient> cache_client_;
WaitPost epoch_sync_;
int32_t rows_per_buffer_;
Connector<std::vector<row_id_type>> keys_miss_;
/// \brief Common function to register resources for interrupt
/// \note Derived should override this function for extra resources to be registered
virtual Status RegisterResources();
/// \brief This function is called by the main thread to send samples to the worker threads.
/// \note It is a non-virtual function
/// \return Status object
Status FetchSamplesToWorkers();
/// \brief This function is called by each worker to fetch rows from the cache server for a given set of
/// sample row ids
/// \return Status object
Status FetchFromCache(int32_t worker_id);
/// \brief Get the column map from the cache server
Status UpdateColumnMapFromCache();
private:
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
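
A sketch, not part of this diff, of the minimum a CacheBase subclass supplies: the two ParallelOp entry points plus the pure virtual AllowCacheMiss(). It assumes ParallelOp contributes no further pure virtuals beyond operator()() and WorkerEntry(); everything outside the header above is illustrative.

#include "dataset/engine/datasetops/cache_base_op.h"

namespace mindspore {
namespace dataset {
class IllustrativeCacheOp : public CacheBase {
 public:
  IllustrativeCacheOp(int32_t num_workers, int32_t connector_size, int32_t rows_per_buf,
                      std::shared_ptr<CacheClient> cc, std::shared_ptr<Sampler> sampler)
      : CacheBase(num_workers, connector_size, rows_per_buf, std::move(cc), std::move(sampler)) {}
  // Main thread: register shared resources, then push sample ids to the workers.
  Status operator()() override {
    RETURN_IF_NOT_OK(RegisterResources());
    return FetchSamplesToWorkers();
  }
  // Worker thread: pull row ids and fetch the corresponding rows from the cache server.
  Status WorkerEntry(int32_t worker_id) override { return FetchFromCache(worker_id); }
  // A miss is tolerated here; a strict cache would return false instead.
  bool AllowCacheMiss() override { return true; }
};
}  // namespace dataset
}  // namespace mindspore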

@ -0,0 +1,130 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/execution_tree.h"
#include "utils/log_adapter.h"
#include "utils/system/crc32c.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
build_op_connector_size_ = cfg->op_connector_size();
}
// Check if the required parameters are set by the builder.
Status CacheLookupOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Cache client for CacheLookupOp is missing session id");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
build_cache_client_, build_sampler_);
return Status::OK();
}
Status CacheLookupOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"CacheLookupOp requires a sampler before it can be executed!");
}
RETURN_IF_NOT_OK(RegisterResources());
// Kick off the workers
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1)));
// required task group sync after launching workers
TaskManager::FindMe()->Post();
// We have to wait until the leaf op has completed its handshake with us.
RETURN_IF_NOT_OK(leaf_op_wp_.Wait());
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
return Status::OK();
}
Status CacheLookupOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
return Status::OK();
}
Status CacheLookupOp::ResetSampler() { return Status::OK(); }
Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
// We act both as a sampler and as a dataset op. During the handshake with the leaf op,
// we must wait until the leaf op has indexed everything.
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op));
// Now we notify the main thread that the handshake has finished.
leaf_op_wp_.Set();
return Status::OK();
}
Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
std::vector<row_id_type> cache_miss;
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
// Skip empty cache-miss lists; we can't return empty samples.
while (cache_miss.empty()) {
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
}
// Special code for eoe
if (cache_miss.at(0) == eoe_row_id) {
*out_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
std::shared_ptr<Tensor> sample_ts;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size()));
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
auto idPtr = sample_ts->begin<int64_t>();
for (size_t i = 0; i < cache_miss.size(); ++i) {
*idPtr = cache_miss.at(i);
++idPtr;
}
TensorRow row;
row.push_back(sample_ts);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
}
return Status::OK();
}
Status CacheLookupOp::RegisterResources() {
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
RETURN_IF_NOT_OK(leaf_op_wp_.Register(tree_->AllTasks()));
return Status::OK();
}
Status CacheLookupOp::ComputeColMap() {
// We don't know the column map at this point unless we contact the cache server
// to fetch the schema, but the cache server may not have it at this point either.
// So we will just return OK and let MergeOp (our parent) handle it.
return Status::OK();
}
// Visitor accept method for NodePass
Status CacheLookupOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CacheLookupOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
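
An illustrative consumer, not part of this PR, of the sampler-side contract implemented above: drain the cache-miss id buffers that GetNextSample() produces until the eoe marker buffer arrives.

#include "dataset/engine/datasetops/cache_lookup_op.h"

namespace mindspore {
namespace dataset {
Status DrainMissIds(CacheLookupOp *lookup) {
  std::unique_ptr<DataBuffer> buf;
  do {
    RETURN_IF_NOT_OK(lookup->GetNextSample(&buf));
    // Each non-eoe buffer carries a single int64 tensor of row ids that
    // missed the cache and must be produced by the miss stream.
  } while (!buf->eoe());
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore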

@ -0,0 +1,122 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief Provides a memory/disk cache that acts as a save-point within a mappable dataset.
/// \note For non-mappable datasets, please see CacheOp
/// \see CacheOp
class CacheLookupOp : public CacheBase, public Sampler {
public:
class Builder {
public:
/// \brief Builder constructor. Creates the builder object.
/// \note No default args
Builder();
/// Default destructor
~Builder() = default;
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheLookupOp object
/// \return Status
Status Build(std::shared_ptr<CacheLookupOp> *ptr);
private:
int32_t build_num_workers_;
int32_t rows_per_buffer_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
// Check if the required parameters are set by the builder.
// \return Status The error code return
Status SanityCheck() const;
};
/// \brief Constructor
/// \note It takes the same argument as the base class.
/// \see CacheBase
CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}
~CacheLookupOp() = default;
// As a parallel op, we override these two functions
Status operator()() override;
Status WorkerEntry(int32_t worker_id) override;
// As a sampler, we override the following functions
Status ResetSampler() override;
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
Status InitSampler() override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
void Print(std::ostream &out, bool show_all) const override;
bool AllowCacheMiss() override { return true; }
std::string Name() const override { return "CacheLookupOp"; }
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
protected:
Status ComputeColMap() override;
private:
WaitPost leaf_op_wp_;
Status RegisterResources() override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
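
Illustrative wiring, not part of this PR: constructing a CacheLookupOp through its Builder. The worker count is a placeholder; SanityCheck() inside Build() enforces a client with a valid session id.

#include "dataset/engine/datasetops/cache_lookup_op.h"

namespace mindspore {
namespace dataset {
Status MakeLookup(std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> user_sampler,
                  std::shared_ptr<CacheLookupOp> *out) {
  return CacheLookupOp::Builder()
    .SetNumWorkers(4)                     // parallel fetch threads
    .SetClient(std::move(cache_client))   // must carry a valid session id
    .SetSampler(std::move(user_sampler))  // the lookup op proxies this sampler
    .Build(out);                          // runs SanityCheck() first
}
}  // namespace dataset
}  // namespace mindspore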

File diff suppressed because it is too large

@ -0,0 +1,196 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "dataset/core/tensor_row.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/util/queue.h"
#include "dataset/util/semaphore.h"
namespace mindspore {
namespace dataset {
/// \brief Provides a method to merge two streams (one from the CacheLookupOp and one from the cache miss stream) into
/// a single stream
class CacheMergeOp : public ParallelOp {
public:
// Some handshake structures among the main thread, cleaner threads and parallel op threads.
class TensorRowRequest {
public:
enum class State : uint8_t {
kEmpty = 0, // No row in the deque
kDirty = 1, // Cleaner hasn't flushed it to the cache server yet.
kClean = 2 // The row has been flushed already.
};
explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {}
~TensorRowRequest() = default;
State GetState() const { return st_; }
void SetState(State newState) { st_ = newState; }
Status Wait(TensorRow *out);
void WakeUpAny(TensorRow &&row);
Status Release(TensorRow *out);
private:
std::mutex dq_mux_;
std::atomic<State> st_;
Semaphore use_count_;
std::deque<TensorRow> row_;
TensorRow cleaner_copy_;
};
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream
/// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of
/// the arguments for constructing it. Use the builder by setting each argument
/// with the provided set methods, and then finally call the build method to execute
/// the actual construction.
class Builder {
public:
/// Builder constructor. Creates the builder object.
/// \note No default args
Builder();
/// Default destructor
~Builder() = default;
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief Setter method
/// \param num_cleaners
/// \return Builder setter method returns reference to the builder.
Builder &SetNumCleaner(int32_t num_cleaners) {
build_num_cleaners_ = num_cleaners;
return *this;
}
/// The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheMergeOp object
/// \return Status
Status Build(std::shared_ptr<CacheMergeOp> *ptr);
private:
int32_t build_num_workers_;
int32_t build_op_connector_size_;
int32_t build_num_cleaners_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
/// Check if the required parameters are set by the builder.
/// \return Status The error code return
Status SanityCheck() const;
};
/// \brief Constructor
/// \param numWorkers Number of parallel workers as a derived class of ParallelOp
/// \param opConnectorSize Connector size as a derived class of ParallelOp
/// \param numCleaners Number of cleaners to move cache miss rows into the cache server
/// \param cache_client CacheClient to communicate with the cache server
/// \param sampler Sampler, passed through to the base class
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
~CacheMergeOp();
void Print(std::ostream &out, bool show_all) const override;
friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
mo.Print(out, false);
return out;
}
/// \brief Master thread responsible for spawning all the necessary worker threads for the two streams and
/// the threads for the cleaners.
/// \return Status object
Status operator()() override;
/// \brief Entry function for the worker threads that fetch rows from the CacheLookupOp
/// \param workerId
/// \return Status object
Status WorkerEntry(int32_t workerId) override;
Status PrepareNodePostAction() override;
/// \brief Entry function for the worker threads that fetch rows from the cache miss stream
/// \param workerId
/// \return Status object
Status CacheMissWorkerEntry(int32_t workerId);
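/// \brief Fetch the TensorRowRequest used to coordinate the given cache-miss row between
/// the worker threads and the cleaners
/// \return Status object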
Status GetRq(row_id_type row_id, TensorRowRequest **);
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for eoe handling
/// \param worker_id
/// \return Status object
Status EoeReceived(int32_t worker_id) override;
protected:
Status ComputeColMap() override;
private:
std::mutex mux_;
std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_;
std::unique_ptr<Queue<row_id_type>> io_que_;
std::shared_ptr<CacheClient> cache_client_;
int32_t num_cleaners_;
/// \brief Entry function for the cleaner threads. Each cleaner is responsible for
/// moving cache miss TensorRows into the CacheServer.
/// \return Status object
Status Cleaner();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
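
Illustrative wiring, not part of this PR. AddChild() is assumed from the wider DatasetOp API (not shown in this diff); what the header above does guarantee is the child ordering encoded by kCacheHitChildIdx (0) and kCacheMissChildIdx (1).

#include "dataset/engine/datasetops/cache_merge_op.h"

namespace mindspore {
namespace dataset {
Status WireMerge(std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler,
                 std::shared_ptr<DatasetOp> lookup_op, std::shared_ptr<DatasetOp> miss_leaf,
                 std::shared_ptr<CacheMergeOp> *out) {
  RETURN_IF_NOT_OK(CacheMergeOp::Builder()
                     .SetNumWorkers(4)
                     .SetNumCleaner(2)  // cleaner threads flushing misses to the server
                     .SetClient(cache_client)
                     .SetSampler(std::move(sampler))
                     .Build(out));
  // Child 0 must be the cache hit stream (the CacheLookupOp),
  // child 1 the cache miss stream (the real leaf pipeline).
  RETURN_IF_NOT_OK((*out)->AddChild(std::move(lookup_op)));
  return (*out)->AddChild(std::move(miss_leaf));
}
}  // namespace dataset
}  // namespace mindspore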

@ -0,0 +1,219 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_op.h"
#include <memory>
#include <vector>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/task_manager.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
build_op_connector_size_ = cfg->op_connector_size();
}
// Check if the required parameters are set by the builder.
Status CacheOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_, build_cache_client_,
build_sampler_);
RETURN_IF_NOT_OK((*ptr)->InitCache());
return Status::OK();
}
// Constructor of CacheOp
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
num_guys_in_(0),
phase_(Phase::kBuildPhase) {}
// Destructor
CacheOp::~CacheOp() = default;
// Private function for cache setup/init work just after construction
Status CacheOp::InitCache() { return Status::OK(); }
// This class functor will provide the master loop that drives the logic for performing the work
Status CacheOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"CacheOp requires a sampler before it can be executed!");
}
RETURN_IF_NOT_OK(RegisterResources());
// Kick off the workers
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
// required task group sync after launching workers
TaskManager::FindMe()->Post();
// Wait for the workers to finish caching the rows.
RETURN_IF_NOT_OK(WaitForCachingAllRows());
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
return Status::OK();
}
Status CacheOp::CacheAllRows(int32_t worker_id) {
// If the current phase is to fill the cache, do it then.
if (phase_ == Phase::kBuildPhase) {
// We will take the chance to cache the schema at the server.
// Just do it once and pick one worker to do it.
if (worker_id == 0) {
RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
}
MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id;
// SAVE mode loop
std::unique_ptr<DataBuffer> db_ptr;
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
while (!db_ptr->eof()) {
if (!db_ptr->eoe()) {
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
} else {
// In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
// as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the
// eoe to indicate the end of the epoch, we should next expect to get the eof.
// Drain this eof so that we don't leave it sitting there on a connector that we'll never fetch
// from again.
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
if (!db_ptr->eof()) {
RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child.");
}
}
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
}
}
// Let the main guy know we are done.
auto last_guy_in = num_guys_in_.fetch_add(1);
if ((last_guy_in + 1) == num_workers_) {
rows_cache_done_.Set();
} else {
// Let's do a sync up here.
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
}
return Status::OK();
}
Status CacheOp::WaitForCachingAllRows() {
// Wait for the workers to finish caching the rows.
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
// Move from build phase to fetch phase if we are the one to fill the cache
if (phase_ == Phase::kBuildPhase) {
RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
// Move to the next phase
phase_ = Phase::kFetchPhase;
}
// Get statistics from the server, and if we are not the one to create the cache,
// wait until the state changes from build phase to fetch phase.
CacheClient::ServiceStat stat{};
bool build_phase_done = true;
do {
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
build_phase_done = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase);
if (!build_phase_done) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
} while (!build_phase_done);
const row_id_type min_key = stat.min_row_id;
const row_id_type max_key = stat.max_row_id;
num_rows_ = max_key - min_key + 1;
MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;
// Now all rows are cached and we have done a sync point check up. The next phase
// is to pick up fetch input from the sampler and pass it up to the caller.
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK();
}
Status CacheOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(CacheAllRows(worker_id));
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
return Status::OK();
}
Status CacheOp::RegisterResources() {
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks()));
return Status::OK();
}
// Base-class override for setting specific CacheOp configurations. This code will be called
// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; }
// Base-class override for special eoe handler.
// CacheOp must override this because it shall not perform default handling of eoe. Instead
// the CacheOp manages actions related to the end of the epoch.
Status CacheOp::EoeReceived(int32_t worker_id) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}
// Base-class override for handling cases when an eof is received.
Status CacheOp::EofReceived(int32_t worker_id) {
// EofReceived is overridden because we want to manually handle this eof.
// Specifically, the default behaviour is to pack it and flow it up to the next connector.
// In this case, we want a no-op behaviour so that we can perform the correct action.
return Status::OK();
}
// Pre-Visitor accept method for NodePass
Status CacheOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<CacheOp>(), modified);
}
// Visitor accept method for NodePass
Status CacheOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CacheOp>(), modified);
}
// A public wrapper for creating the cache through the client
Status CacheOp::CreateCache(uint32_t cache_crc) {
// This is a non-mappable cache op so the ids need to be generated.
// Construct the cache
const bool generate_ids = true;
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
if (rc.get_code() == StatusCode::kDuplicateKey) {
// We are told the cache has been created already. So we skip the build phase.
phase_ = Phase::kFetchPhase;
rc = Status::OK();
}
RETURN_IF_NOT_OK(rc);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
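
Illustrative only: how a tree pass might attach the CacheOp above. The CRC helper below is hypothetical; what the file does show is that CreateCache() treats kDuplicateKey as "cache already exists" and jumps straight to the fetch phase.

// Assumed helper, not in this PR: hash a pipeline description so that
// identical pipelines map to the same crc and therefore share one cache.
uint32_t HypotheticalPipelineCrc(const std::string &fingerprint);

Status AttachCache(const std::shared_ptr<CacheOp> &cache_op, const std::string &fingerprint) {
  uint32_t cache_crc = HypotheticalPipelineCrc(fingerprint);
  // The first pipeline creates the cache and runs the build phase; later
  // pipelines get kDuplicateKey internally and start in Phase::kFetchPhase.
  return cache_op->CreateCache(cache_crc);
}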

Some files were not shown because too many files have changed in this diff
