commit
eadcb341e1
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,8 @@
|
||||
# Tag every .cc file found under this directory with the submodule id used by the logging macros.
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

# Client half: the CacheClient itself plus the request objects it builds.
add_library(engine-cache-client OBJECT cache_client.cc cache_request.cc)

# Server half: the per-cache service and the server that hosts it.
add_library(engine-cache-server OBJECT cache_service.cc cache_server.cc)
|
@ -0,0 +1,208 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iomanip>
|
||||
#include "dataset/engine/cache/cache_client.h"
|
||||
#include "dataset/engine/cache/cache_request.h"
|
||||
#include "dataset/util/bit.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor
|
||||
// Build a client that is not yet attached to any cache on the server: the server connection id and
// the crc both start at 0 and are filled in by CreateCache().
// The initializers are listed in member declaration order, which is the order in which they are
// actually executed; this keeps -Wreorder quiet and avoids misleading the reader.
CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
    : cache_mem_sz_(cache_mem_sz), spill_(spill), session_id_(session_id), cache_crc_(0), server_connection_id_(0) {}
|
||||
|
||||
// print method for display cache details
|
||||
// Dump the client's identifying fields and configuration to the given stream, one field per line.
// Note that std::boolalpha is left set on the stream afterwards.
void CacheClient::Print(std::ostream &out) const {
  out << " Session id: " << session_id_;
  out << "\n Cache crc: " << cache_crc_;
  out << "\n Server cache id: " << server_connection_id_;
  out << "\n Cache mem size: " << cache_mem_sz_;
  out << "\n Spilling: " << std::boolalpha << spill_;
}
|
||||
|
||||
// Send one TensorRow to the cache server and block until the server has answered.
// If the caller supplied row_id_from_server, it receives the row id carried back by the request.
Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
  CacheRowRequest request(server_connection_id_, cookie());
  RETURN_IF_NOT_OK(request.SerializeCacheRowRequest(row));
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&request));
  RETURN_IF_NOT_OK(request.Wait());
  if (row_id_from_server == nullptr) {
    // Caller is not interested in the assigned id.
    return Status::OK();
  }
  *row_id_from_server = request.GetRowIdAfterCache();
  return Status::OK();
}
|
||||
|
||||
// Send every row of a DataBuffer to the cache server.
// The buffer is taken over by this function. One CacheRowRequest is pushed per row without waiting,
// and only after all of them have been queued do we wait for each one to complete.
// Returns an error if the buffer pointer is null or if any pop/serialize/push/wait step fails.
Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
  std::unique_ptr<DataBuffer> db_ptr = std::move(in);
  // Guard against a null buffer instead of dereferencing it below.
  RETURN_UNEXPECTED_IF_NULL(db_ptr);
  const auto num_rows = db_ptr->NumRows();
  if (num_rows <= 0) {
    return Status::OK();
  }
  // The serialized requests only reference the tensor memory, so every row is parked here to keep
  // that memory alive until this function returns.
  std::vector<TensorRow> all_rows;
  all_rows.reserve(num_rows);
  MemGuard<CacheRowRequest> rq_arr;
  RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
  CacheServer &cs = CacheServer::GetInstance();
  // NOTE(review): an early return inside this loop destroys rq_arr while requests pushed in earlier
  // iterations may still be pending on the server - confirm the server tolerates that.
  // The index uses the same type as num_rows to avoid a mixed-type comparison.
  for (decltype(db_ptr->NumRows()) i = 0; i < num_rows; ++i) {
    TensorRow row;
    auto rq = rq_arr[i];
    RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
    RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
    RETURN_IF_NOT_OK(cs.PushRequest(rq));
    all_rows.push_back(std::move(row));
  }
  // Everything is queued; now wait for the requests to be done.
  for (decltype(db_ptr->NumRows()) i = 0; i < num_rows; ++i) {
    auto rq = rq_arr[i];
    RETURN_IF_NOT_OK(rq->Wait());
  }
  return Status::OK();
}
|
||||
|
||||
// Fetch the rows with the given ids from the cache server in one batch request and rebuild them
// into *out. Fails if out is null or if any step of the round trip fails.
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
  RETURN_UNEXPECTED_IF_NULL(out);
  BatchFetchRequest fetch_rq(server_connection_id_, row_id);
  CacheServer &server = CacheServer::GetInstance();
  RETURN_IF_NOT_OK(server.PushRequest(&fetch_rq));
  RETURN_IF_NOT_OK(fetch_rq.Wait());
  RETURN_IF_NOT_OK(fetch_rq.RestoreRows(out));
  return Status::OK();
}
|
||||
|
||||
// Create (or attach to) the cache identified by this client's session id and the given tree crc.
//
// The client identifies a cache by two values packed into one 64 bit id:
//   - the shared session id (upper 32 bits)
//   - a crc of the tree nodes from the cache downward (lower 32 bits)
//
// Example of sharing that is NOT allowed (same session, different cached data):
//   tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
//   tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch
//
// Example of sharing that IS allowed (different trees, identical data below the cache):
//   tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
//   tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch
//
// The crc is computed by the tree prepare phase and handed to this function.
// A kDuplicateKey status is passed back to the caller on purpose: it is not an error, it tells the
// CacheOp to bypass the build phase.
Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
  UniqueLock lck(&mux_);

  if (server_connection_id_ == 0) {
    // First use of this client: it's really a new cache we're creating, so remember the crc.
    cache_crc_ = tree_crc;
    // Combine the session and crc. This will form our client cache identifier.
    connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
    // Translate the client configuration into creation flags.
    BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
    if (spill_) {
      createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
    }
    if (generate_id) {
      createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
    }
    CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
    RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
    Status rc = rq.Wait();
    const bool created = rc.IsOk();
    if (created || rc.get_code() == StatusCode::kDuplicateKey) {
      server_connection_id_ = rq.GetServerConnectionId();
      if (created) {
        // Only the first creator gets a cookie back. This object may be shared among pipelines,
        // so the cookie is not touched on the duplicate-key path.
        cookie_ = rq.cookie();
      }
    }
    // The duplicate key return code is deliberately not reset here.
    return rc;
  }

  // We already hold a server connection id: this client was used before to create a cache and
  // another tree is trying to use the same cache. That is allowed, however the crc better match!
  if (cache_crc_ != tree_crc) {
    RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
  }
  // Check the state of the server. For the non-mappable case where there is a build phase and a
  // fetch phase, we should skip the build phase.
  lck.Unlock();  // GetStat will grab the mutex again. So unlock it to prevent deadlock.
  ServiceStat stat{};
  RETURN_IF_NOT_OK(GetStat(&stat));
  if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
    return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
  }
  return Status::OK();
}
|
||||
|
||||
// Ask the server to purge the cache this client is connected to, holding the client lock
// exclusively for the whole round trip. Returns the status reported for the request.
Status CacheClient::PurgeCache() {
  UniqueLock lck(&mux_);
  PurgeCacheRequest purge_rq(server_connection_id_);
  CacheServer &server = CacheServer::GetInstance();
  RETURN_IF_NOT_OK(server.PushRequest(&purge_rq));
  Status rc = purge_rq.Wait();
  return rc;
}
|
||||
|
||||
// Ask the server to destroy the cache this client is connected to, holding the client lock
// exclusively for the whole round trip. Returns the status reported for the request.
Status CacheClient::DestroyCache() {
  UniqueLock lck(&mux_);
  DestroyCacheRequest destroy_rq(server_connection_id_);
  CacheServer &server = CacheServer::GetInstance();
  RETURN_IF_NOT_OK(server.PushRequest(&destroy_rq));
  Status rc = destroy_rq.Wait();
  return rc;
}
|
||||
|
||||
// Query the server for the statistics of the connected cache and copy them into *stat.
// The client lock is held in shared mode. Fails if stat is null or the round trip fails.
Status CacheClient::GetStat(ServiceStat *stat) {
  SharedLock lck(&mux_);
  RETURN_UNEXPECTED_IF_NULL(stat);
  GetStatRequest stat_rq(server_connection_id_);
  CacheServer &server = CacheServer::GetInstance();
  RETURN_IF_NOT_OK(server.PushRequest(&stat_rq));
  RETURN_IF_NOT_OK(stat_rq.Wait());
  // Copy each field of the reply into the caller's structure.
  ServiceStat &result = *stat;
  result.num_disk_cached = stat_rq.GetNumDiskCached();
  result.num_mem_cached = stat_rq.GetNumMemCached();
  result.min_row_id = stat_rq.GetMinRowId();
  result.max_row_id = stat_rq.GetMaxRowId();
  result.cache_service_state = stat_rq.GetState();
  return Status::OK();
}
|
||||
|
||||
// Serialize the column-name-to-id map and send it to the server for the connected cache.
// The client lock is held in shared mode for the whole round trip.
Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
  SharedLock lck(&mux_);
  CacheSchemaRequest schema_rq(server_connection_id_);
  RETURN_IF_NOT_OK(schema_rq.SerializeCacheSchemaRequest(map));
  CacheServer &server = CacheServer::GetInstance();
  RETURN_IF_NOT_OK(server.PushRequest(&schema_rq));
  RETURN_IF_NOT_OK(schema_rq.Wait());
  return Status::OK();
}
|
||||
|
||||
// Retrieve the column-name-to-id map stored on the server for the connected cache and copy it
// into *map. Fails if map is null or the round trip fails.
Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
  SharedLock lck(&mux_);
  RETURN_UNEXPECTED_IF_NULL(map);
  FetchSchemaRequest schema_rq(server_connection_id_);
  CacheServer &server = CacheServer::GetInstance();
  RETURN_IF_NOT_OK(server.PushRequest(&schema_rq));
  RETURN_IF_NOT_OK(schema_rq.Wait());
  *map = schema_rq.GetColumnMap();
  return Status::OK();
}
|
||||
|
||||
// Tell the server that the build phase of the connected cache is finished. The request carries
// this client's cookie. The client lock is held in shared mode for the whole round trip.
Status CacheClient::BuildPhaseDone() const {
  SharedLock lck(&mux_);
  BuildPhaseDoneRequest done_rq(server_connection_id_, cookie());
  CacheServer &server = CacheServer::GetInstance();
  RETURN_IF_NOT_OK(server.PushRequest(&done_rq));
  RETURN_IF_NOT_OK(done_rq.Wait());
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,141 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_CACHE_CLIENT_H_
|
||||
#define DATASET_ENGINE_CACHE_CLIENT_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "./de_tensor_generated.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/cache/cache_server.h"
|
||||
#include "dataset/util/lock.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through
|
||||
/// a CacheClient. Typical tasks include creating a cache service, caching a data buffer, and restoring previously cached
|
||||
/// rows, etc.
|
||||
class CacheClient {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param session_id A user assigned session id for the current pipeline
|
||||
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
|
||||
/// \param spill Spill to disk if out of memory
|
||||
CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheClient() = default;
|
||||
|
||||
/// \brief Getter function for returning the current session id
|
||||
/// \return session id
|
||||
uint64_t session_id() const { return session_id_; }
|
||||
|
||||
/// \brief Send a TensorRow to the cache server
|
||||
/// \param[in] row
|
||||
/// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
|
||||
/// \return return code
|
||||
Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;
|
||||
|
||||
/// \brief Send a DataBuffer to the cache server
|
||||
/// \param in Unique pointer of the DataBuffer to be cached
|
||||
/// \return return code
|
||||
Status WriteBuffer(std::unique_ptr<DataBuffer> &&in) const;
|
||||
|
||||
/// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
|
||||
/// any cache miss
|
||||
/// \param row_id A vector of row id's
|
||||
/// \param out A TensorTable of TensorRows.
|
||||
/// \return return code
|
||||
Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;
|
||||
|
||||
/// \brief Create a cache.
|
||||
/// \param tree_crc A crc that was generated during tree prepare phase
|
||||
/// \param generate_id Let the cache service generate row id
|
||||
/// \return Status object
|
||||
Status CreateCache(uint32_t tree_crc, bool generate_id);
|
||||
|
||||
/// \brief Purge a cache. Cache can be reused after reset.
|
||||
/// \return Status object
|
||||
Status PurgeCache();
|
||||
|
||||
/// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
|
||||
/// \return Status object
|
||||
Status DestroyCache();
|
||||
|
||||
/// \brief Get the statistics from a cache.
|
||||
/// \param[in/out] Pointer to a pre-allocated ServiceStat object
|
||||
/// \return Status object
|
||||
struct ServiceStat {
|
||||
int64_t num_mem_cached;
|
||||
int64_t num_disk_cached;
|
||||
row_id_type min_row_id;
|
||||
row_id_type max_row_id;
|
||||
int8_t cache_service_state;
|
||||
};
|
||||
Status GetStat(ServiceStat *);
|
||||
|
||||
/// \brief Cache the schema at the cache server
|
||||
/// \param map The unordered map of the schema
|
||||
/// \return Status object
|
||||
Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);
|
||||
|
||||
/// \brief Fetch the schema from the cache server
|
||||
/// \param map Pointer to pre-allocated map object
|
||||
/// \return Status object.
|
||||
Status FetchSchema(std::unordered_map<std::string, int32_t> *map);
|
||||
|
||||
/// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache
|
||||
/// client that holds cookie can be allowed to make this request
|
||||
/// \return Status object
|
||||
Status BuildPhaseDone() const;
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
/// \param out The output stream to write output to
|
||||
void Print(std::ostream &out) const;
|
||||
|
||||
/// \brief Stream output operator overload
|
||||
/// \return the output stream must be returned
|
||||
friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) {
|
||||
cc.Print(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
/// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it.
|
||||
/// \return Cookie
|
||||
std::string cookie() const { return cookie_; }
|
||||
|
||||
private:
|
||||
mutable RWLock mux_;
|
||||
uint64_t cache_mem_sz_;
|
||||
bool spill_;
|
||||
// The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
|
||||
// sharing of the cache.
|
||||
uint32_t session_id_;
|
||||
uint32_t cache_crc_;
|
||||
// The server_connection_id_ is the actual id we use for operations after the cache is built
|
||||
connection_id_type server_connection_id_;
|
||||
// Some magic cookie returned from the cache server.
|
||||
std::string cookie_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_CACHE_CLIENT_H_
|
@ -0,0 +1,223 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/cache/cache_request.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Prepare this request for one TensorRow: build the flat buffer header, then record the header
// pointer followed by each tensor's data pointer in buffers_. No tensor data is copied here.
Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
  // One slot for the header plus one per column.
  const auto num_buffers = row.size() + 1;
  buffers_.reserve(num_buffers);
  RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
  // The header always comes first, followed by the columns in row order.
  buffers_.push_back(fbb_->GetBufferPointer());
  for (const auto &tensor : row) {
    buffers_.push_back(tensor->GetBuffer());
  }
  return Status::OK();
}
|
||||
|
||||
// Build the flat buffer header for a row into a fresh builder held in fbb_. The header carries
// the per-column metadata, the per-column data size in bytes, the row id and its own total size.
// A std::bad_alloc thrown while building is converted into a kOutOfMemory status.
Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
  try {
    fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
    const auto num_columns = row.size();
    std::vector<flatbuffers::Offset<TensorMetaMsg>> column_meta;
    std::vector<int64_t> column_bytes;
    column_meta.reserve(num_columns);
    column_bytes.reserve(num_columns);
    // Walk the columns once, collecting the metadata offset and byte size of each tensor.
    for (const std::shared_ptr<Tensor> &tensor : row) {
      flatbuffers::Offset<TensorMetaMsg> meta_off;
      RETURN_IF_NOT_OK(SerializeOneTensorMeta(tensor, &meta_off));
      column_meta.push_back(meta_off);
      column_bytes.push_back(tensor->SizeInBytes());
    }
    // Both vectors are written into the buffer before the table builder is opened.
    auto column_off = fbb_->CreateVector(column_meta);
    auto data_sz_off = fbb_->CreateVector(column_bytes);
    TensorRowHeaderMsgBuilder header(*fbb_);
    header.add_column(column_off);
    header.add_data_sz(data_sz_off);
    // Pass the row_id even if it may not be known.
    header.add_row_id(row.getId());
    // Placeholder value; the real size is only known after Finish and is patched in below.
    header.add_size_of_this(-1);
    auto root = header.Finish();
    fbb_->Finish(root);
    // Now go back and overwrite size_of_this in place with the finished buffer size.
    auto *finished = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
    if (!finished->mutate_size_of_this(fbb_->GetSize())) {
      RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
    }
    return Status::OK();
  } catch (const std::bad_alloc &) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
  }
}
|
||||
|
||||
// Write the metadata (shape and element type) of one tensor into the flat buffer builder fbb_.
// \param ts_ptr  Tensor to describe; must not be null and must have a non-null data buffer.
// \param out_off Receives the offset of the finished TensorMetaMsg; must not be null.
// \return Error status if a pointer is null, the tensor has no buffer, or its type is not one of
//         the types listed in the switch below.
Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
                                               flatbuffers::Offset<TensorMetaMsg> *out_off) {
  RETURN_UNEXPECTED_IF_NULL(out_off);
  // A row may carry an empty shared_ptr; reject it here instead of dereferencing it.
  RETURN_UNEXPECTED_IF_NULL(ts_ptr);
  const Tensor *ts = ts_ptr.get();
  // Validate the tensor before anything is written into the builder.
  const auto ptr = ts->GetBuffer();
  if (ptr == nullptr) {
    RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
  }
  auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
  auto src = ts->type().value();
  // Value-initialized so it is never read indeterminate; every path that reaches the builder
  // below has assigned it in the switch.
  TensorType dest{};
#define CASE(t)                        \
  case DataType::t:                    \
    dest = TensorType::TensorType_##t; \
    break
  // Map the type to fill in the flat buffer.
  switch (src) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
    default:
      MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
      RETURN_STATUS_UNEXPECTED("Unknown type");
  }
#undef CASE

  TensorMetaMsgBuilder ts_builder(*fbb_);
  ts_builder.add_dims(shape_off);
  ts_builder.add_type(dest);
  auto ts_off = ts_builder.Finish();
  *out_off = ts_off;
  return Status::OK();
}
|
||||
|
||||
// Rebuild one Tensor from its flat buffer metadata and the slice holding its raw bytes.
// \param col_ts Metadata (dims and type) of the column; must not be null.
// \param data   Slice covering exactly the tensor's bytes.
// \param out    Receives the restored tensor; must not be null.
// \return Error status if a pointer is null, the metadata has no dims, the type is not one of the
//         types listed in the switch below, or the slice length differs from the size implied by
//         shape and type.
Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data,
                                           std::shared_ptr<Tensor> *out) {
  RETURN_UNEXPECTED_IF_NULL(col_ts);
  RETURN_UNEXPECTED_IF_NULL(out);
  auto shape_in = col_ts->dims();
  // The accessor hands back a pointer; make sure it is usable before reading through it.
  if (shape_in == nullptr) {
    RETURN_STATUS_UNEXPECTED("Tensor metadata has no dims");
  }
  auto type_in = col_ts->type();
  std::vector<dsize_t> v(shape_in->begin(), shape_in->end());
  TensorShape shape(v);
  DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t)               \
  case TensorType_##t:        \
    dest = DataType::Type::t; \
    break

  switch (type_in) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
    default:
      // Mirror the serializing side, which refuses to write a type it does not know.
      RETURN_STATUS_UNEXPECTED("Unknown type");
  }
#undef CASE

  DataType type(dest);
  std::shared_ptr<Tensor> ts =
    std::make_shared<Tensor>(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize());
  // The bytes received must match exactly what shape and type say the tensor holds.
  if (ts->SizeInBytes() != data.GetSize()) {
    MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
                  << "Dumping tensor\n"
                  << *ts << "\n";
    RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
  }
  *out = std::move(ts);
  return Status::OK();
}
|
||||
|
||||
// Rebuild the fetched rows from the reply held in mem_.
// The start of mem_ is read as an array of int64 offsets, entry i and i + 1 delimiting row i.
// A zero-length entry produces a row with no columns, which still carries its requested row id.
// Each non-empty entry starts with a flat buffer header followed by the raw bytes of each column.
// \param out Receives the restored table; must not be null. Only written on success.
// \return Error status if out is null, the reply buffer is missing, a row header is malformed, or
//         a tensor cannot be restored.
Status BatchFetchRequest::RestoreRows(TensorTable *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  const auto num_elements = row_id_.size();
  auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer());
  // With rows requested there must be a reply to read the offsets from.
  if (num_elements > 0 && offset_array == nullptr) {
    RETURN_STATUS_UNEXPECTED("No data returned from the cache server");
  }
  TensorTable tbl;
  tbl.reserve(num_elements);
  ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes());
  // The index has the same (unsigned) type as the container size to avoid signed/unsigned compares.
  for (decltype(row_id_.size()) i = 0; i < num_elements; ++i) {
    auto len = offset_array[i + 1] - offset_array[i];
    TensorRow row;
    row.setId(row_id_.at(i));
    if (len > 0) {
      ReadableSlice row_data(all, offset_array[i], len);
      // Next we de-serialize flat buffer to get back each column
      auto msg = GetTensorRowHeaderMsg(row_data.GetPointer());
      const auto *columns = msg->column();
      const auto *sizes = msg->data_sz();
      if (columns == nullptr || sizes == nullptr || columns->size() != sizes->size()) {
        RETURN_STATUS_UNEXPECTED("Malformed row header received from the cache server");
      }
      auto msg_sz = msg->size_of_this();
      // The tensor data starts right after the header.
      auto ts_offset = msg_sz;
      row.reserve(columns->size());
      for (decltype(columns->size()) k = 0; k < columns->size(); ++k) {
        auto col_ts = columns->Get(k);
        std::shared_ptr<Tensor> ts;
        ReadableSlice data(row_data, ts_offset, sizes->Get(k));
        RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts));
        row.push_back(ts);
        ts_offset += data.GetSize();
      }
    }
    tbl.push_back(std::move(row));
  }
  *out = std::move(tbl);
  return Status::OK();
}
|
||||
|
||||
// Serialize a column-name-to-id map into a fresh flat buffer held in fbb_, then point buf_ and
// len_of_buf_ at the finished buffer. A std::bad_alloc is converted into a kOutOfMemory status.
Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
  try {
    fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
    std::vector<flatbuffers::Offset<ColumnNameMsg>> column_msgs;
    column_msgs.reserve(map.size());
    // One ColumnNameMsg per (name, id) entry.
    for (auto &entry : map) {
      auto name_off = fbb_->CreateString(entry.first);
      column_msgs.push_back(CreateColumnNameMsg(*fbb_, name_off, entry.second));
    }
    auto columns_off = fbb_->CreateVector(column_msgs);
    auto schema_off = CreateSchemaMsg(*fbb_, columns_off);
    fbb_->Finish(schema_off);
    // Expose the finished buffer; it is owned by fbb_.
    buf_ = fbb_->GetBufferPointer();
    len_of_buf_ = fbb_->GetSize();
    return Status::OK();
  } catch (const std::bad_alloc &) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
  }
}
|
||||
|
||||
std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() {
|
||||
if (column_name_id_map_.empty()) {
|
||||
auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer());
|
||||
auto v = map_msg->column();
|
||||
for (auto i = 0; i < v->size(); ++i) {
|
||||
auto col = map_msg->column()->Get(i);
|
||||
column_name_id_map_.emplace(col->name()->str(), col->id());
|
||||
}
|
||||
}
|
||||
return column_name_id_map_;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,225 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_CACHE_REQ_H_
|
||||
#define DATASET_ENGINE_CACHE_REQ_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "./de_tensor_generated.h"
|
||||
#include "dataset/core/tensor_row.h"
|
||||
#include "dataset/util/slice.h"
|
||||
#include "dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief CacheClient communicates with CacheServer using Requests.
|
||||
class BaseRequest {
|
||||
public:
|
||||
// Request types
|
||||
enum class RequestType : int16_t {
|
||||
kCacheRow = 0,
|
||||
kBatchFetchRows = 1,
|
||||
kCreateCache = 2,
|
||||
kPurgeCache = 3,
|
||||
kDestroyCache = 4,
|
||||
kGetStat = 5,
|
||||
kCacheSchema = 6,
|
||||
kFetchSchema = 7,
|
||||
kBuildPhaseDone = 8,
|
||||
// Add new request before it.
|
||||
kRequestUnknown = 32767
|
||||
};
|
||||
// For kCreateCache
|
||||
enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };
|
||||
friend class CacheServer;
|
||||
/// \brief Base class of a cache server request
|
||||
/// \param connection_id A combination of session id and crc that uniquely identifies a connection.
|
||||
/// \param type Type of the request
|
||||
explicit BaseRequest(connection_id_type connection_id, RequestType type)
|
||||
: type_(type), connection_id_(connection_id) {}
|
||||
virtual ~BaseRequest() = default;
|
||||
/// \brief Wait for the completion of a request
|
||||
/// \return Status returned from the cache server
|
||||
Status Wait() {
|
||||
RETURN_IF_NOT_OK(wp_.Wait());
|
||||
return rc_;
|
||||
}
|
||||
|
||||
/// \brief Getter function of the current connection id
|
||||
/// \return Connection id
|
||||
connection_id_type GetServerConnectionId() const { return connection_id_; }
|
||||
|
||||
private:
|
||||
RequestType type_;
|
||||
connection_id_type connection_id_;
|
||||
Status rc_;
|
||||
WaitPost wp_;
|
||||
};
|
||||
/// \brief Request to cache a single TensorRow
|
||||
class CacheRowRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie)
|
||||
: BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {}
|
||||
~CacheRowRequest() = default;
|
||||
|
||||
/// \brief Serialize a TensorRow for streaming to the cache server
|
||||
/// \param row TensorRow
|
||||
/// \return Status object
|
||||
Status SerializeCacheRowRequest(const TensorRow &row);
|
||||
/// \brief Return the row id assigned to this row for non-mappable dataset
|
||||
/// \return row id of the cached row
|
||||
row_id_type GetRowIdAfterCache() { return row_id_from_server_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
|
||||
row_id_type row_id_from_server_;
|
||||
std::vector<const void *> buffers_;
|
||||
std::string cookie_;
|
||||
|
||||
/// \brief Private function to serialize one TensorRow
|
||||
/// \param row TensorRow
|
||||
/// \return Status object
|
||||
Status SerializeTensorRowHeader(const TensorRow &row);
|
||||
/// \brief Private function to serialize one Tensor
|
||||
/// \param ts_ptr Tensor
|
||||
/// \return Status object
|
||||
Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off);
|
||||
};
|
||||
/// \brief Request to fetch rows in batch
|
||||
class BatchFetchRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
friend class CacheService;
|
||||
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id)
|
||||
: BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {}
|
||||
Status RestoreRows(TensorTable *out);
|
||||
|
||||
private:
|
||||
std::vector<row_id_type> row_id_;
|
||||
MemGuard<uint8_t> mem_;
|
||||
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
|
||||
};
|
||||
/// \brief Request to create a cache for the current connection
|
||||
class CreationCacheRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
/// \brief Constructor
|
||||
/// \param connection_id
|
||||
/// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
|
||||
/// \param flag Attributes of the cache.
|
||||
explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz,
|
||||
CreateCacheFlag flag = CreateCacheFlag::kNone)
|
||||
: BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {}
|
||||
|
||||
std::string cookie() const { return cookie_; }
|
||||
|
||||
private:
|
||||
uint64_t cache_mem_sz;
|
||||
CreateCacheFlag flag_;
|
||||
std::string cookie_;
|
||||
};
|
||||
/// \brief Request to purge a cache.
|
||||
class PurgeCacheRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {}
|
||||
};
|
||||
/// \brief Request to destroy a cache
|
||||
class DestroyCacheRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
explicit DestroyCacheRequest(connection_id_type connection_id)
|
||||
: BaseRequest(connection_id, RequestType::kDestroyCache) {}
|
||||
};
|
||||
/// \brief Obtain the statistics of the current connection
|
||||
class GetStatRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
friend class CacheService;
|
||||
explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {}
|
||||
row_id_type GetMinRowId() const {
|
||||
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||
return msg->min_row_id();
|
||||
}
|
||||
row_id_type GetMaxRowId() const {
|
||||
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||
return msg->max_row_id();
|
||||
}
|
||||
int64_t GetNumMemCached() const {
|
||||
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||
return msg->num_mem_cached();
|
||||
}
|
||||
int64_t GetNumDiskCached() const {
|
||||
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||
return msg->num_disk_cached();
|
||||
}
|
||||
uint8_t GetState() const {
|
||||
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||
return msg->state();
|
||||
}
|
||||
|
||||
private:
|
||||
MemGuard<uint8_t> mem_;
|
||||
};
|
||||
/// \brief Request to cache a schema
|
||||
class CacheSchemaRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
explicit CacheSchemaRequest(connection_id_type connection_id)
|
||||
: BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {}
|
||||
~CacheSchemaRequest() = default;
|
||||
|
||||
Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
|
||||
const void *GetBuffer() const { return buf_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
|
||||
const void *buf_;
|
||||
int64_t len_of_buf_;
|
||||
};
|
||||
/// \brief Request to fetch a schema
|
||||
class FetchSchemaRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
explicit FetchSchemaRequest(connection_id_type connection_id)
|
||||
: BaseRequest(connection_id, RequestType::kFetchSchema) {}
|
||||
~FetchSchemaRequest() = default;
|
||||
|
||||
std::unordered_map<std::string, int32_t> GetColumnMap();
|
||||
|
||||
private:
|
||||
MemGuard<uint8_t> mem_;
|
||||
std::unordered_map<std::string, int32_t> column_name_id_map_;
|
||||
};
|
||||
/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
|
||||
class BuildPhaseDoneRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
|
||||
: BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {}
|
||||
|
||||
private:
|
||||
std::string cookie_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_CACHE_SERVICE_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,98 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_CACHE_SERVER_H_
|
||||
#define DATASET_ENGINE_CACHE_SERVER_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "dataset/engine/cache/cache_service.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/util/arena.h"
|
||||
#include "dataset/util/cache_pool.h"
|
||||
#include "dataset/util/lock.h"
|
||||
#include "dataset/util/service.h"
|
||||
#include "dataset/util/services.h"
|
||||
#include "dataset/util/system_pool.h"
|
||||
#include "dataset/util/queue.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class BaseRequest;
|
||||
/// \brief A server which provides CacheService services.
|
||||
class CacheServer : public Service {
|
||||
public:
|
||||
friend class Services;
|
||||
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
|
||||
|
||||
CacheServer(const CacheServer &) = delete;
|
||||
CacheServer &operator=(const CacheServer &) = delete;
|
||||
CacheServer(CacheServer &&) = delete;
|
||||
CacheServer &operator=(CacheServer &) = delete;
|
||||
static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
|
||||
Status DoServiceStart() override;
|
||||
Status DoServiceStop() override;
|
||||
~CacheServer() { (void)ServiceStop(); }
|
||||
|
||||
/// \brief For the current demonstration, a cache client contacts cache server using a Queue.
|
||||
/// \param rq
|
||||
/// \return Status object
|
||||
Status PushRequest(BaseRequest *rq) {
|
||||
RETURN_UNEXPECTED_IF_NULL(rq);
|
||||
RETURN_IF_NOT_OK(cache_q_->Add(rq));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
mutable RWLock rwLock_;
|
||||
std::string top_;
|
||||
cache_index all_caches_;
|
||||
std::shared_ptr<Queue<BaseRequest *>> cache_q_;
|
||||
TaskGroup vg_;
|
||||
int32_t num_workers_;
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param spill_path Top directory for spilling buffers to.
|
||||
/// \param num_workers Number of threads for handling requests.
|
||||
explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3);
|
||||
|
||||
/// \brief Locate a cache service from connection id.
|
||||
/// \return Pointer to cache service. Null if not found
|
||||
CacheService *GetService(connection_id_type id) const;
|
||||
|
||||
/// \brief Create a cache service. We allow multiple clients to create the same cache service.
|
||||
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
|
||||
/// a special unique cookie.
|
||||
/// \param[in] connection_id This is from a Cache client.
|
||||
/// \param[in] cache_mem_sz
|
||||
/// \param[in] flag
|
||||
/// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator
|
||||
/// \return Status object
|
||||
Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag,
|
||||
std::string *out_cookie);
|
||||
|
||||
/// \brief Entry point for all server threads.
|
||||
Status ServerRequest();
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif  // DATASET_ENGINE_CACHE_SERVER_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,143 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_CACHE_SERVICE_H_
|
||||
#define DATASET_ENGINE_CACHE_SERVICE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "./de_tensor_generated.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/engine/cache/cache_request.h"
|
||||
#include "dataset/util/arena.h"
|
||||
#include "dataset/util/btree.h"
|
||||
#include "dataset/util/cache_pool.h"
|
||||
#include "dataset/util/service.h"
|
||||
#include "dataset/util/services.h"
|
||||
#include "dataset/util/system_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
struct CacheStat;
|
||||
/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is
|
||||
/// created to support spilling
|
||||
class CacheService : public Service {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
using row_map = BPlusTree<row_id_type, CachePool::key_type>;
|
||||
|
||||
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase };
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
|
||||
/// \param root Spill path. Empty string means no spilling
|
||||
/// \param generate_id If the cache service should generate row id for buffer that is cached.
|
||||
/// For non-mappable dataset, this should be set to true.
|
||||
CacheService(uint64_t mem_sz, const std::string &root, bool generate_id);
|
||||
~CacheService();
|
||||
|
||||
/// \brief For fixed size memory, we will create an Arena.
|
||||
/// \return false if unlimited memory.
|
||||
bool UseArena();
|
||||
|
||||
Status DoServiceStart() override;
|
||||
Status DoServiceStop() override;
|
||||
|
||||
/// \brief Main function to cache a row which is in form a series of buffers.
|
||||
/// The first buffer is a Google flatbuffer which describes the rest of the buffers followed.
|
||||
/// \param[in] buf Vector of buffer
|
||||
/// \param[out] row_id_generated The row id assigned to this row if any
|
||||
/// \return Status object
|
||||
Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);
|
||||
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
|
||||
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
|
||||
/// \param[in] v A vector of row id.
|
||||
/// \param[out] out A contiguous memory buffer that holds the requested rows.
|
||||
/// \return Status object
|
||||
Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const;
|
||||
|
||||
/// \brief Getter function
|
||||
/// \return Spilling path
|
||||
Path GetSpillPath() const;
|
||||
/// \brief A structure returned from the cache server for statistics request.
|
||||
class ServiceStat {
|
||||
public:
|
||||
using state_type = std::underlying_type<State>::type;
|
||||
ServiceStat() : min_(0), max_(0), state_(0) {}
|
||||
CachePool::CacheStat stat_{};
|
||||
row_id_type min_;
|
||||
row_id_type max_;
|
||||
state_type state_;
|
||||
};
|
||||
/// \brief Statistics for the current service
|
||||
/// \param[in/out] A pointer to a pre-allocated ServiceStat structure
|
||||
/// \return Status Object
|
||||
Status GetStat(ServiceStat *);
|
||||
/// \brief Cache schema
|
||||
/// \param buf A Google Flatbuffer that contains the schema
|
||||
/// \param len size of the buffer
|
||||
/// \return Status object
|
||||
Status CacheSchema(const void *buf, int64_t len);
|
||||
/// \brief Fetch schema
|
||||
/// \param out A contiguous memory that contains the serialized form of schema.
|
||||
/// \return Status object
|
||||
Status FetchSchema(MemGuard<uint8_t> *out) const;
|
||||
/// \brief Purge the content of a cache
|
||||
/// \return Status object
|
||||
Status Purge();
|
||||
/// \brief Overload the << operator to print a cache service
|
||||
/// \param out std::ostream
|
||||
/// \param cs A cache service
|
||||
/// \return std::ostream
|
||||
friend std::ostream &operator<<(std::ostream &out, const CacheService &cs);
|
||||
/// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient
|
||||
/// is the creator
|
||||
/// \return Cookie
|
||||
std::string cookie() const { return cookie_; }
|
||||
/// \brief If this cache service generates row id for buffer cached, it is divided into two phases, a build phase and
|
||||
/// a read phase.
|
||||
/// \return True if has two phases.
|
||||
bool HasBuildPhase() const { return generate_id_; }
|
||||
/// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
|
||||
/// \return Status object
|
||||
Status BuildPhaseDone();
|
||||
|
||||
private:
|
||||
mutable RWLock rw_lock_;
|
||||
std::string root_;
|
||||
uint64_t cache_mem_sz_;
|
||||
std::shared_ptr<CachePool> cp_;
|
||||
std::shared_ptr<row_map> map_;
|
||||
std::atomic<row_id_type> next_id_;
|
||||
bool generate_id_;
|
||||
std::atomic<CachePool::key_type> schema_key_;
|
||||
std::string cookie_;
|
||||
State st_;
|
||||
|
||||
/// \brief Private function to generate a row id
|
||||
/// \return Row id assigned.
|
||||
row_id_type GetNextRowId() { return next_id_.fetch_add(1); }
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_CACHE_SERVICE_H_
|
@ -0,0 +1,81 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
namespace mindspore.dataset;
|
||||
|
||||
/// Element type of a Tensor, stored as a single byte.
enum TensorType : byte {
  DE_UNKNOWN = 0,
  DE_BOOL = 1,
  DE_INT8 = 2,
  DE_UINT8 = 3,
  DE_INT16 = 4,
  DE_UINT16 = 5,
  DE_INT32 = 6,
  DE_UINT32 = 7,
  DE_INT64 = 8,
  DE_UINT64 = 9,
  DE_FLOAT16 = 10,
  DE_FLOAT32 = 11,
  DE_FLOAT64 = 12,
  DE_STRING = 13
}
|
||||
|
||||
/// The meta information of a Tensor.
/// \note Only the type and shape are considered meta information. Tensor data is excluded.
/// \param dims The shape of the Tensor (required)
/// \param type The element type of the Tensor
table TensorMetaMsg {
  dims:[int64] (required);
  type:TensorType;
}
|
||||
|
||||
/// This is the first buffer that is sent to a Cache server when a TensorRow is serialized.
/// \param row_id The row id of the TensorRow
/// \param column The meta information of each Tensor in the row
/// \param size_of_this Size of this serialized buffer
/// \param data_sz Size of each tensor data buffer that follows
table TensorRowHeaderMsg {
  row_id:int64;
  column:[TensorMetaMsg] (required);
  size_of_this:int64;
  data_sz:[int64] (required);
}

root_type TensorRowHeaderMsg;
|
||||
|
||||
/// A list of row ids.
table TensorRowIds {
  row_id:[int64] (required);
}
|
||||
|
||||
/// Statistics returned from each cache service.
/// \note It must match CacheService::ServiceStat
table ServiceStatMsg {
  num_mem_cached:int64;
  num_disk_cached:int64;
  min_row_id:int64;
  max_row_id:int64;
  state:int8;
}
|
||||
|
||||
/// Description of a single column in a schema: its name and its id.
table ColumnNameMsg {
  name:string;
  id:int32;
}
|
||||
|
||||
/// Serialized form of a schema: a list of column descriptions.
table SchemaMsg {
  column:[ColumnNameMsg];
}
|
@ -0,0 +1,185 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/cache_base_op.h"
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// A print method typically used for debugging
|
||||
void CacheBase::Print(std::ostream &out, bool show_all) const {
|
||||
// Always show the id and name as first line regardless if this summary or detailed print
|
||||
out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:";
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nCache client:\n" << *cache_client_ << "\n\n";
|
||||
}
|
||||
}
|
||||
// Overrides base class reset method. When an operator does a reset, it cleans up any state
|
||||
// info from it's previous execution and then initializes itself so that it can be executed
|
||||
// again.
|
||||
Status CacheBase::Reset() {
|
||||
if (sampler_ != nullptr) {
|
||||
RETURN_IF_NOT_OK(sampler_->ResetSampler());
|
||||
}
|
||||
// Wake up the workers to get them going again in a new epoch
|
||||
MS_LOG(DEBUG) << Name() << " resetting.";
|
||||
epoch_sync_.Set();
|
||||
return Status::OK();
|
||||
}
|
||||
// Constructor.
// \param num_workers Number of parallel workers
// \param op_connector_size Connector size
// \param rows_per_buf Number of rows per buffer
// \param cache_client CacheClient used to talk to the cache server
// \param sampler Sampler handed to the ParallelOp base class
// The shared pointers are taken by value, so they are moved (not copied) into their final owners.
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
                     std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
      cache_client_(std::move(cache_client)),
      rows_per_buffer_(rows_per_buf),
      // We can cause deadlock if this internal Connector size is too small.
      keys_miss_(num_workers_, 1, 1024) {
  io_block_queues_.Init(num_workers, op_connector_size);
}
|
||||
// Common function to fetch samples from the sampler and send them using the io_block_queues to
|
||||
// the parallel workers
|
||||
Status CacheBase::FetchSamplesToWorkers() {
|
||||
int64_t buf_cnt = 0;
|
||||
int64_t wait_cnt = 0;
|
||||
do {
|
||||
epoch_sync_.Clear();
|
||||
std::vector<row_id_type> keys;
|
||||
int64_t row_cnt = 0;
|
||||
keys.reserve(rows_per_buffer_);
|
||||
std::unique_ptr<DataBuffer> sampler_buffer;
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
while (!sampler_buffer->eoe()) {
|
||||
TensorRow sample_row;
|
||||
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
|
||||
std::shared_ptr<Tensor> sample_ids = sample_row[0];
|
||||
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
|
||||
keys.push_back(*itr);
|
||||
++row_cnt;
|
||||
if (row_cnt % rows_per_buffer_ == 0) {
|
||||
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
|
||||
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
|
||||
keys.clear();
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
if (!keys.empty()) {
|
||||
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
|
||||
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
|
||||
}
|
||||
// send the eoe
|
||||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
// If repeat but the not last repeat, wait for reset.
|
||||
if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
|
||||
RETURN_IF_NOT_OK(epoch_sync_.Wait());
|
||||
} else {
|
||||
// We can break out from the loop.
|
||||
break;
|
||||
}
|
||||
} while (true);
|
||||
// Flow the eof before exit
|
||||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
|
||||
// Ask all the workers to quit.
|
||||
for (int32_t i = 0; i < num_workers_; i++) {
|
||||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Worker loop. Pops IOBlocks from this worker's queue until the quit signal (an IOBlock with no keys) arrives.
//  - eof block: forward an EOF buffer to the output connector.
//  - eoe block: forward an EOE buffer; if cache misses are allowed, also push the eoe marker to keys_miss_.
//  - otherwise: fetch the rows for the keys from the cache and forward them in one DataBuffer.
// \param worker_id Id of the calling worker; selects the queue to read from and the connector slot to write to
// \return Status object
Status CacheBase::FetchFromCache(int32_t worker_id) {
  int64_t buffer_id = worker_id;
  std::unique_ptr<IOBlock> blk;
  do {
    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
    if (blk->eof()) {
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
    } else if (blk->eoe()) {
      if (AllowCacheMiss()) {
        // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from
        // a sampler, send a eoe to physical leaf op as well.
        std::vector<row_id_type> eoe;
        eoe.push_back(eoe_row_id);
        RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe));
      }
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
    } else {
      std::vector<int64_t> keys;
      RETURN_IF_NOT_OK(blk->GetKeys(&keys));
      if (keys.empty()) {
        // empty key is a quit signal for workers
        break;
      }
      std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
      std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
      TensorTable ttbl;
      RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl));
      auto row_it = ttbl.begin();
      std::vector<row_id_type> cache_miss;
      cache_miss.reserve(keys.size());
      for (auto row_id : keys) {
        // Rows are consumed in lock-step with the keys. Never step past the end of the table if the
        // cache returned fewer rows than were requested.
        if (row_it == ttbl.end()) {
          std::string errMsg = "No row returned from the cache for row id " + std::to_string(row_id) + ".";
          RETURN_STATUS_UNEXPECTED(errMsg);
        }
        auto &row = *row_it;
        if (row.empty()) {
          if (AllowCacheMiss()) {
            cache_miss.push_back(row_id);
          } else {
            std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
            RETURN_STATUS_UNEXPECTED(errMsg);
          }
        }
        que->push_back(std::move(row));
        ++row_it;
      }
      db->set_tensor_table(std::move(que));
      if (AllowCacheMiss()) {
        // Because of the way connector works, we push unconditionally even cache_miss can be empty.
        RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss));
      }
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
      buffer_id += num_workers_;
    }
  } while (true);
  return Status::OK();
}
|
||||
// Register epoch_sync_ and the io block queues with the tasks of the execution tree.
// \return Status object; the first failing registration is returned.
Status CacheBase::RegisterResources() {
  RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
  return io_block_queues_.Register(tree_->AllTasks());
}
|
||||
// Destructor. Nothing to release explicitly; the members clean up after themselves.
CacheBase::~CacheBase() = default;
|
||||
// Populate column_name_id_map_ from the schema stored in the cache server, unless it is already populated.
// \return Status object. A kFileNotExist reply from the server is downgraded to OK.
Status CacheBase::UpdateColumnMapFromCache() {
  Status rc;
  if (!column_name_id_map_.empty()) {
    // Nothing to do, the map is already known.
    return rc;
  }
  // Get the schema from the server. It may not be there yet. So tolerate the error.
  rc = cache_client_->FetchSchema(&column_name_id_map_);
  if (rc == Status(StatusCode::kFileNotExist)) {
    MS_LOG(DEBUG) << "Schema not in the server yet.";
    return Status::OK();
  }
  return rc;
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,108 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "dataset/engine/cache/cache_client.h"
|
||||
#include "dataset/engine/cache/cache_service.h"
|
||||
#include "dataset/engine/datasetops/parallel_op.h"
|
||||
#include "dataset/engine/datasetops/repeat_op.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/util/queue.h"
|
||||
#include "dataset/util/wait_post.h"
|
||||
#include "dataset/engine/datasetops/cache_base_op.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
|
||||
/// \see CacheOp
|
||||
/// \see CacheLookupOp
|
||||
class CacheBase : public ParallelOp {
|
||||
public:
|
||||
/// \brief Base class constructor
|
||||
/// \param num_workers Number of parallel workers
|
||||
/// \param op_connector_size Connector size
|
||||
/// \param rows_per_buf Number of rows per buffer
|
||||
/// \param cache_client CacheClient for communication to the CacheServer
|
||||
/// \param sampler Sampler which is mandatory
|
||||
CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
|
||||
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
|
||||
/// \brief Destructor
|
||||
~CacheBase();
|
||||
|
||||
constexpr static int eoe_row_id = -1;
|
||||
|
||||
/// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state
|
||||
/// info from it's previous execution and then initializes itself so that it can be executed
|
||||
/// again.
|
||||
/// \return Status - The error code return
|
||||
Status Reset() override;
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
/// \param out The output stream to write output to
|
||||
/// \param show_all A bool to control if you want to show all info or just a summary
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief << Stream output operator overload
|
||||
/// \notes This allows you to write the debug print info using stream operators
|
||||
/// \param out reference to the output stream being overloaded
|
||||
/// \param mo reference to the CacheOp to display
|
||||
/// \return the output stream must be returned
|
||||
friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) {
|
||||
mo.Print(out, false);
|
||||
return out;
|
||||
}
|
||||
|
||||
/// \brief Getter for the cache client
|
||||
/// \return shared ptr to the cache client
|
||||
std::shared_ptr<CacheClient> cache_client() { return cache_client_; }
|
||||
/// \brief Setter for the cache client
|
||||
void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
|
||||
/// \brief Derived class must implement this method if a cache miss is treated as error
|
||||
virtual bool AllowCacheMiss() = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<CacheClient> cache_client_;
|
||||
WaitPost epoch_sync_;
|
||||
int32_t rows_per_buffer_;
|
||||
Connector<std::vector<row_id_type>> keys_miss_;
|
||||
|
||||
/// \brief Common function to register resources for interrupt
|
||||
/// \note Derived should override this function for extra resources to be registered
|
||||
virtual Status RegisterResources();
|
||||
/// \brief This function is called by main thread to send samples to the worker thread.
|
||||
/// \note It is a non-virtual function
|
||||
/// \return Status object
|
||||
Status FetchSamplesToWorkers();
|
||||
/// \brief This function is called by each worker to fetch rows from the cache server for a given set of
|
||||
/// sample row id's
|
||||
/// \return Status object
|
||||
Status FetchFromCache(int32_t worker_id);
|
||||
/// \brief Get the column map from cache server
|
||||
Status UpdateColumnMapFromCache();
|
||||
|
||||
private:
|
||||
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
|
@ -0,0 +1,130 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/core/constants.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/system/crc32c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Builder constructor. Creates the builder object.
|
||||
CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
build_num_workers_ = cfg->num_parallel_workers();
|
||||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
build_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
// Check if the required parameters are set by the builder.
|
||||
Status CacheLookupOp::Builder::SanityCheck() const {
|
||||
if (build_cache_client_ == nullptr) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient");
|
||||
}
|
||||
// Make sure the cache client has a valid session
|
||||
if (!build_cache_client_->session_id()) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Cache client for CacheLookupOp is missing session id");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The builder "build" method creates the final object and does some init on it
|
||||
Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
|
||||
build_cache_client_, build_sampler_);
|
||||
return Status::OK();
|
||||
}
|
||||
// Main entry point of the op: launch the workers, wait for the handshake with the leaf op and then
// feed the sample ids to the workers.
Status CacheLookupOp::operator()() {
  if (sampler_ == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
                  "CacheLookupOp requires a sampler before it can be executed!");
  }
  RETURN_IF_NOT_OK(RegisterResources());
  // Kick off the workers
  auto worker_fn = std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1);
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, worker_fn));
  // required task group sync after launching workers
  TaskManager::FindMe()->Post();
  // We have to wait until the leaf op has handshake with us.
  RETURN_IF_NOT_OK(leaf_op_wp_.Wait());
  return FetchSamplesToWorkers();
}
|
||||
// Entry point of each worker thread. Posts to the task manager first and then serves rows from the cache.
// \param worker_id Id of this worker
// \return Status object from FetchFromCache
Status CacheLookupOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  return FetchFromCache(worker_id);
}
|
||||
// Nothing to reset here; always succeeds.
// NOTE(review): the wrapped sampler_ is reset by CacheBase::Reset - confirm no extra state needs resetting.
Status CacheLookupOp::ResetSampler() {
  return Status::OK();
}
|
||||
// Handshake with the leaf op. The request is forwarded to the wrapped sampler, after which the
// main thread waiting in operator() is released.
// \param op The random access (leaf) op to handshake with
// \return Status object
Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
  // We act like a sampler and as a dataset op. During handshake with leaf op,
  // we must wait until the leaf op has indexed everything.
  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op));
  // Now we notify the main thread that the handshake has finished.
  leaf_op_wp_.Set();
  return Status::OK();
}
|
||||
// Sampler override: delegates straight to the Sampler base-class implementation.
Status CacheLookupOp::InitSampler() {
  return Sampler::InitSampler();
}
|
||||
// Prints this op by delegating to the CacheBase implementation.
// \param out Stream to write to.
// \param show_all Passed through unchanged to CacheBase::Print.
void CacheLookupOp::Print(std::ostream &out, bool show_all) const {
  CacheBase::Print(out, show_all);
}
|
||||
// Sampler override: hands the next batch of cache-miss row ids to the caller as a sample buffer.
// \param out_buffer Output: receives either an EOE buffer or a buffer holding one tensor row
//                   whose single tensor contains the missed row ids.
// \return Status error from the queue or tensor creation, otherwise OK.
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
  std::vector<row_id_type> cache_miss;
  // Ignore the case we have no cache miss, we can't return empty samples.
  do {
    RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
  } while (cache_miss.empty());

  // Special code for eoe
  if (cache_miss.front() == eoe_row_id) {
    *out_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    return Status::OK();
  }

  std::shared_ptr<Tensor> sample_ts;
  RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size()));
  *out_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);

  // Copy every missed row id into the sampler tensor. A range-for avoids the signed/unsigned
  // comparison of an int index against vector::size().
  auto id_ptr = sample_ts->begin<int64_t>();
  for (const row_id_type row_id : cache_miss) {
    *id_ptr = row_id;
    ++id_ptr;
  }

  TensorRow row;
  row.push_back(sample_ts);
  (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
  return Status::OK();
}
|
||||
// Registers the base-class resources first, then the wait post used for the leaf op handshake.
Status CacheLookupOp::RegisterResources() {
  const Status base_rc = CacheBase::RegisterResources();
  RETURN_IF_NOT_OK(base_rc);
  const Status wp_rc = leaf_op_wp_.Register(tree_->AllTasks());
  RETURN_IF_NOT_OK(wp_rc);
  return Status::OK();
}
|
||||
// Intentionally a no-op. The column map is unknown here unless the schema is fetched from the
// cache server, and the server may not have it yet either. We therefore report success and
// leave the column map to MergeOp (our parent).
Status CacheLookupOp::ComputeColMap() {
  return Status::OK();
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status CacheLookupOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<CacheLookupOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,122 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "dataset/engine/datasetops/cache_base_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset.
|
||||
/// \note For non-mappable dataset, please see CacheOp
|
||||
/// \see CacheOp
|
||||
class CacheLookupOp : public CacheBase, public Sampler {
|
||||
public:
|
||||
class Builder {
|
||||
public:
|
||||
/// \brief Builder constructor. Creates the builder object.
|
||||
/// \note No default args
|
||||
Builder();
|
||||
|
||||
/// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
/// Setter method.
|
||||
/// \treturn Builder setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
build_num_workers_ = num_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t connector_size) {
|
||||
build_op_connector_size_ = connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
||||
build_cache_client_ = cache_client;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
|
||||
build_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief The builder "build" method creates the final object and does some init on it.
|
||||
/// \param ptr The shared_ptr to the new CacheLookupOp object
|
||||
/// \return Status
|
||||
Status Build(std::shared_ptr<CacheLookupOp> *ptr);
|
||||
|
||||
private:
|
||||
int32_t build_num_workers_;
|
||||
int32_t rows_per_buffer_;
|
||||
int32_t build_op_connector_size_;
|
||||
std::shared_ptr<CacheClient> build_cache_client_;
|
||||
std::shared_ptr<Sampler> build_sampler_;
|
||||
|
||||
// Check if the required parameters are set by the builder.
|
||||
// \return Status The error code return
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
/// \brief Constructor
|
||||
/// \note It takes the same argument as the base class.
|
||||
/// \see CacheBase
|
||||
CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
|
||||
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
|
||||
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}
|
||||
~CacheLookupOp() = default;
|
||||
// As a parallel op, we override these two functions
|
||||
Status operator()() override;
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
// As a sampler, we override the following functions
|
||||
Status ResetSampler() override;
|
||||
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
|
||||
Status InitSampler() override;
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
bool AllowCacheMiss() override { return true; }
|
||||
std::string Name() const override { return "CacheLookupOp"; }
|
||||
|
||||
/// \brief Base-class override for NodePass visitor acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
protected:
|
||||
Status ComputeColMap() override;
|
||||
|
||||
private:
|
||||
WaitPost leaf_op_wp_;
|
||||
|
||||
Status RegisterResources() override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,196 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <deque>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "dataset/core/tensor_row.h"
|
||||
#include "dataset/engine/cache/cache_client.h"
|
||||
#include "dataset/engine/datasetops/parallel_op.h"
|
||||
#include "dataset/engine/dataset_iterator.h"
|
||||
#include "dataset/util/queue.h"
|
||||
#include "dataset/util/semaphore.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Provides method to merge two streams (one from CacheLookup and one from cache miss stream) into one single
|
||||
/// stream
|
||||
class CacheMergeOp : public ParallelOp {
|
||||
public:
|
||||
// Some handshake structures among the main thread, cleaner threads and parallel op threads.
|
||||
class TensorRowRequest {
|
||||
public:
|
||||
enum class State : uint8_t {
|
||||
kEmpty = 0, // No row in the deque
|
||||
kDirty = 1, // Cleaner hasn't flushed it to the cache server yet.
|
||||
kClean = 2 // The row has been flushed already.
|
||||
};
|
||||
explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {}
|
||||
~TensorRowRequest() = default;
|
||||
State GetState() const { return st_; }
|
||||
void SetState(State newState) { st_ = newState; }
|
||||
Status Wait(TensorRow *out);
|
||||
void WakeUpAny(TensorRow &&row);
|
||||
Status Release(TensorRow *out);
|
||||
|
||||
private:
|
||||
std::mutex dq_mux_;
|
||||
std::atomic<State> st_;
|
||||
Semaphore use_count_;
|
||||
std::deque<TensorRow> row_;
|
||||
TensorRow cleaner_copy_;
|
||||
};
|
||||
|
||||
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
|
||||
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream
|
||||
|
||||
/// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of
|
||||
/// the arguments for constructing it. Use the builder by setting each argument
|
||||
/// with the provided set methods, and then finally call the build method to execute
|
||||
/// the actual construction.
|
||||
class Builder {
|
||||
public:
|
||||
/// Builder constructor. Creates the builder object.
|
||||
/// \note No default args
|
||||
Builder();
|
||||
|
||||
/// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
build_num_workers_ = num_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t connector_size) {
|
||||
build_op_connector_size_ = connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
||||
build_cache_client_ = cache_client;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Setter method
|
||||
/// \param sampler
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
|
||||
build_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Setter method
|
||||
/// \param num_cleaners
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumCleaner(int32_t num_cleaners) {
|
||||
build_num_cleaners_ = num_cleaners;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// The builder "build" method creates the final object and does some init on it.
|
||||
/// \param ptr The shared_ptr to the new CacheMergeOp object
|
||||
/// \return Status
|
||||
Status Build(std::shared_ptr<CacheMergeOp> *ptr);
|
||||
|
||||
private:
|
||||
int32_t build_num_workers_;
|
||||
int32_t build_op_connector_size_;
|
||||
int32_t build_num_cleaners_;
|
||||
std::shared_ptr<CacheClient> build_cache_client_;
|
||||
std::shared_ptr<Sampler> build_sampler_;
|
||||
|
||||
/// Check if the required parameters are set by the builder.
|
||||
/// \return Status The error code return
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param numWorkers Number of parallel workers as a derived class of ParallelOp
|
||||
/// \param opConnector Size Connector size as a derived class of ParallelOp
|
||||
/// \param numCleaners Number of cleaners to move cache miss rows into the cache server
|
||||
/// \param cache_client CacheClient to commmunicate with the Cache server
|
||||
/// \param sampler as a derived class of ParallelOp
|
||||
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
|
||||
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
|
||||
~CacheMergeOp();
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
|
||||
mo.Print(out, false);
|
||||
return out;
|
||||
}
|
||||
/// \brief Master thread responsible to spawn all the necessary worker threads for the two streams and
|
||||
/// the threads for the cleaners.
|
||||
/// \return
|
||||
Status operator()() override;
|
||||
/// \brief Entry function for worker thread that fetch rows from CacheLookupOp
|
||||
/// \param workerId
|
||||
/// \return Status object
|
||||
Status WorkerEntry(int32_t workerId) override;
|
||||
Status PrepareNodePostAction() override;
|
||||
/// \brief Entry function for worker thread that fetch rows from the cache miss stream
|
||||
/// \param workerId
|
||||
/// \return Status object
|
||||
Status CacheMissWorkerEntry(int32_t workerId);
|
||||
Status GetRq(row_id_type row_id, TensorRowRequest **);
|
||||
|
||||
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status PreAccept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for NodePass visitor acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for eoe handling
|
||||
/// \param worker_id
|
||||
/// \return Status object
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
|
||||
protected:
|
||||
Status ComputeColMap() override;
|
||||
|
||||
private:
|
||||
std::mutex mux_;
|
||||
std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_;
|
||||
std::unique_ptr<Queue<row_id_type>> io_que_;
|
||||
std::shared_ptr<CacheClient> cache_client_;
|
||||
int32_t num_cleaners_;
|
||||
|
||||
/// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for
|
||||
/// moving cache miss TensorRow into the CacheServer.
|
||||
/// \return Status object
|
||||
Status Cleaner();
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
|
@ -0,0 +1,219 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/cache_op.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/core/constants.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
#include "dataset/engine/datasetops/repeat_op.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Builder constructor. Creates the builder object.
|
||||
CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
build_num_workers_ = cfg->num_parallel_workers();
|
||||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
build_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
// Check if the required parameters are set by the builder.
|
||||
Status CacheOp::Builder::SanityCheck() const {
|
||||
if (build_cache_client_ == nullptr) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient");
|
||||
}
|
||||
// Make sure the cache client has a valid session
|
||||
if (!build_cache_client_->session_id()) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The builder "build" method creates the final object and does some init on it
|
||||
Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_, build_cache_client_,
|
||||
build_sampler_);
|
||||
RETURN_IF_NOT_OK((*ptr)->InitCache());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor of CacheOp
|
||||
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
|
||||
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
|
||||
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
|
||||
num_guys_in_(0),
|
||||
phase_(Phase::kBuildPhase) {}
|
||||
|
||||
// Destructor
|
||||
CacheOp::~CacheOp() = default;
|
||||
|
||||
// Private function for cache setup/init work just after construction
|
||||
Status CacheOp::InitCache() { return Status::OK(); }
|
||||
|
||||
// This class functor will provide the master loop that drives the logic for performing the work
|
||||
Status CacheOp::operator()() {
|
||||
if (!sampler_) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"CacheOp requires a sampler before it can be executed!");
|
||||
}
|
||||
RETURN_IF_NOT_OK(RegisterResources());
|
||||
// Kick off the workers
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
|
||||
// required task group sync after launching workers
|
||||
TaskManager::FindMe()->Post();
|
||||
// Wait for the workers to finish caching the rows.
|
||||
RETURN_IF_NOT_OK(WaitForCachingAllRows());
|
||||
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
|
||||
return Status::OK();
|
||||
}
|
||||
// Worker-side build phase: drains the child and writes every data buffer to the cache, then
// synchronises with the other workers.
// \param worker_id Id of the calling worker.
Status CacheOp::CacheAllRows(int32_t worker_id) {
  // Only fill the cache when we are in the build phase.
  if (phase_ == Phase::kBuildPhase) {
    // Take the chance to cache the schema at the server. Once is enough, so worker 0 does it.
    if (worker_id == 0) {
      RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
    }
    MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id;

    // SAVE mode loop
    std::unique_ptr<DataBuffer> in_buf;
    RETURN_IF_NOT_OK(this->GetNextInput(&in_buf, worker_id, 0));
    while (!in_buf->eof()) {
      if (in_buf->eoe()) {
        // In a repeat-over-cache scenario, the "real" leaf operators below us have been set up as
        // non-repeating leaf ops: they run one epoch and quit. Having seen the eoe that ends the
        // epoch, the next buffer must be the eof. Drain it so it is not left sitting on a
        // connector that we will never fetch from again.
        RETURN_IF_NOT_OK(this->GetNextInput(&in_buf, worker_id, 0));
        if (!in_buf->eof()) {
          RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child.");
        }
      } else {
        // Regular data buffer: ship it to the cache.
        RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(in_buf)));
      }
      // NOTE(review): when the eoe branch above has just consumed the eof, this fetch still runs
      // once more before the loop condition is re-tested - confirm this is intended.
      RETURN_IF_NOT_OK(this->GetNextInput(&in_buf, worker_id, 0));
    }
  }

  // Check in. The last worker to arrive releases everyone; the others wait for that signal.
  const auto arrived_before = num_guys_in_.fetch_add(1);
  if ((arrived_before + 1) != num_workers_) {
    // Let's do a sync up here.
    RETURN_IF_NOT_OK(rows_cache_done_.Wait());
  } else {
    rows_cache_done_.Set();
  }
  return Status::OK();
}
|
||||
// Master-side sync point: waits for the workers to finish caching, moves the cache into the
// fetch phase, records the number of cached rows and performs the sampler handshake.
Status CacheOp::WaitForCachingAllRows() {
  // Block until the workers report that all rows are cached.
  RETURN_IF_NOT_OK(rows_cache_done_.Wait());

  // If we are the one that filled the cache, tell the server the build phase is over.
  if (phase_ == Phase::kBuildPhase) {
    RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
    phase_ = Phase::kFetchPhase;
  }

  // Poll the server statistics. If someone else created the cache, keep polling every 100ms
  // until the service state has changed from build phase to fetch phase.
  CacheClient::ServiceStat stat{};
  for (;;) {
    RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
    const bool in_fetch_phase = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase);
    if (in_fetch_phase) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }

  // The row count is derived from the smallest and largest row id reported by the server.
  const row_id_type first_key = stat.min_row_id;
  const row_id_type last_key = stat.max_row_id;
  num_rows_ = last_key - first_key + 1;
  MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
  MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
  MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;

  // All rows are cached and the sync point has been passed. The next phase picks up fetch
  // input from the sampler and passes it up to the caller.
  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
  return Status::OK();
}
|
||||
// Worker thread body: signals that the thread is up, caches all rows (build phase) and then
// serves fetch requests from the cache.
// \param worker_id Id of this worker.
Status CacheOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  const Status build_rc = CacheAllRows(worker_id);
  RETURN_IF_NOT_OK(build_rc);
  const Status fetch_rc = FetchFromCache(worker_id);
  RETURN_IF_NOT_OK(fetch_rc);
  return Status::OK();
}
|
||||
// Registers the base-class resources first, then this op's own wait post and cache-miss queue,
// all against the tree's task group.
Status CacheOp::RegisterResources() {
  const Status base_rc = CacheBase::RegisterResources();
  RETURN_IF_NOT_OK(base_rc);
  const Status done_rc = rows_cache_done_.Register(tree_->AllTasks());
  RETURN_IF_NOT_OK(done_rc);
  const Status miss_rc = keys_miss_.Register(tree_->AllTasks());
  RETURN_IF_NOT_OK(miss_rc);
  return Status::OK();
}
|
||||
|
||||
// Base-class override for setting specific CacheOp configurations. This code will be called
|
||||
// during the execution tree prepare phase BEFORE traversing down to child operators.
|
||||
uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; }
|
||||
// Base-class override for special eoe handler.
|
||||
// CacheOp must override this because it shall not perform default handling of eoe. Instead
|
||||
// the CacheOp manages actions related to the end of the epoch.
|
||||
Status CacheOp::EoeReceived(int32_t worker_id) {
|
||||
state_ = OpState::kDeOpIdle;
|
||||
return Status::OK();
|
||||
}
|
||||
// Base-class override for handling cases when an eof is received.
|
||||
Status CacheOp::EofReceived(int32_t worker_id) {
|
||||
// eofReceived is overloaded because we want to manually handle this eof.
|
||||
// Specifically, the default behaviour is to pack it and flow it up to the next connection.
|
||||
// In this case, we want a no-op behaviour so that we can perform correct action.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pre-Visitor accept method for NodePass
|
||||
Status CacheOp::PreAccept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call the pre-visitation
|
||||
return p->PreRunOnNode(shared_from_base<CacheOp>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status CacheOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<CacheOp>(), modified);
|
||||
}
|
||||
|
||||
// A public wrapper for creating the cache through the client
|
||||
Status CacheOp::CreateCache(uint32_t cache_crc) {
|
||||
// This is a non-mappable cache op so the id's need to be generated.
|
||||
// Construct the cache
|
||||
const bool generate_ids = true;
|
||||
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
|
||||
if (rc.get_code() == StatusCode::kDuplicateKey) {
|
||||
// We are told the cache has been created already. So we skip the build phase.
|
||||
phase_ = Phase::kFetchPhase;
|
||||
rc = Status::OK();
|
||||
}
|
||||
RETURN_IF_NOT_OK(rc);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue