diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt index 8d7da15b22..4b84c4d797 100644 --- a/mindspore/ccsrc/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/CMakeLists.txt @@ -47,6 +47,8 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/dataset/include) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") +ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR}) + ################## Include sub-modules ############################### add_subdirectory(util) add_subdirectory(core) @@ -55,7 +57,7 @@ add_subdirectory(engine) add_subdirectory(api) add_subdirectory(text) ###################################################################### -add_dependencies(core utils) +add_dependencies(utils core) add_dependencies(kernels-image core) add_dependencies(kernels-data core) add_dependencies(kernels core) @@ -89,6 +91,8 @@ set(submodules $ $ $ + $ + $ $ $ $ @@ -106,6 +110,8 @@ else () add_library(_c_dataengine SHARED ${submodules}) endif () +add_dependencies(_c_dataengine generated_engine_files) + set_target_properties(_c_dataengine PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}" SUFFIX "${PYTHON_MODULE_EXTENSION}" diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index 78fcdb7dd4..6d4a60cdc5 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -21,8 +21,10 @@ #include "common/utils.h" #include "dataset/core/tensor.h" +#include "dataset/engine/cache/cache_client.h" #include "dataset/engine/dataset_iterator.h" #include "dataset/engine/datasetops/bucket_batch_by_length_op.h" +#include "dataset/engine/datasetops/cache_op.h" #include "dataset/engine/datasetops/filter_op.h" #include "dataset/engine/datasetops/source/celeba_op.h" #include "dataset/engine/datasetops/source/cifar_op.h" @@ -34,6 +36,7 @@ #include "dataset/engine/datasetops/source/random_data_op.h" #include "dataset/engine/datasetops/source/text_file_op.h" #include "dataset/engine/datasetops/source/voc_op.h" +#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/kernels/py_func_op.h" #include "dataset/util/random.h" #include "dataset/util/status.h" @@ -441,6 +444,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * MapOp::Builder map_builder; std::vector> tensor_op_list; std::vector project_columns; + std::shared_ptr cache_client = nullptr; + int num_workers = 0; if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. 
\n"); @@ -456,7 +461,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * } else if (key == "columns_order") { project_columns = ToStringVector(value); } else if (key == "num_parallel_workers") { - (void)map_builder.SetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)map_builder.SetNumWorkers(num_workers); } else if (key == "prefetch_size") { (void)map_builder.SetOpConnectorSize(ToInt(value)); } else if (key == "operations") { @@ -477,6 +483,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * } if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set."); (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); + } else if (key == "cache") { + cache_client = value.cast>(); } else { RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); } @@ -499,6 +507,15 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * *bottom = map_op; } + // Additionally, add a cache if required. This will go over top of the project op if one + // was created, otherwise it goes over top of the map op + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op)); + *top = cache_op; + *bottom = map_op; + } + return Status::OK(); } @@ -809,6 +826,9 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *bottom) { // Required arguments std::vector files_list; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { files_list = ToStringVector(args["dataset_files"]); @@ -828,7 +848,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptrSetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "columns_list") { columns_to_load = ToStringVector(value); (void)builder->SetColumnsToLoad(columns_to_load); @@ -848,6 +869,11 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptrSetDeviceId(ToInt(value)); } else if (key == "shard_equal_rows") { (void)builder->SetShardEqualRows(ToBool(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + sampler = create().cast>(); } } } @@ -860,12 +886,27 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptrSetDataSchema(std::move(schema)); } + + // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed + // because TFReaderOp is a non-mappable dataset that does not support sampling. + // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. 
+  if (sampler) {
+    (void)builder->SetSampler(std::move(sampler));
+  } else if (cache_client) {
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+    (void)builder->SetSampler(std::move(sampler));
+  }
+
   std::shared_ptr<TFReaderOp> tf_op;
   RETURN_IF_NOT_OK(builder->Build(&tf_op));
   RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op));
   *top = tf_op;

-  if (shuffle_required) {
+  if (!cache_client && shuffle_required) {
     const boolean estimate = true;
     const int64_t workers = 8;
     std::shared_ptr<DatasetOp> shuffle_op = nullptr;
@@ -882,6 +923,15 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
   }

+  // Add a cache op over this op if required and update the output subtree (top/bottom)
+  if (cache_client) {
+    std::shared_ptr<CacheOp> cache_op = nullptr;
+    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op));
+    *top = cache_op;
+    *bottom = tf_op;
+  }
+
   return Status::OK();
 }

@@ -906,6 +956,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                                      std::shared_ptr<DatasetOp> *bottom) {
   // Required arguments
+  int num_workers = 0;
+  std::shared_ptr<CacheClient> cache_client = nullptr;
   std::shared_ptr<ImageFolderOp::Builder> builder = std::make_shared<ImageFolderOp::Builder>();
   (void)builder->SetImageFolderDir(ToString(args["dataset_dir"]));
@@ -915,7 +967,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
-        (void)builder->SetNumWorkers(ToInt(value));
+        num_workers = ToInt(value);
+        (void)builder->SetNumWorkers(num_workers);
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
         std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
@@ -926,12 +979,27 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
         (void)builder->SetClassIndex(ToStringMap(value));
       } else if (key == "decode") {
         (void)builder->SetDecode(ToBool(value));
+      } else if (key == "cache") {
+        cache_client = value.cast<std::shared_ptr<CacheClient>>();
       }
     }
   }
-  std::shared_ptr<ImageFolderOp> op;
-  RETURN_IF_NOT_OK(builder->Build(&op));
-  *top = op;
+  std::shared_ptr<ImageFolderOp> if_op;
+  RETURN_IF_NOT_OK(builder->Build(&if_op));
+  RETURN_IF_NOT_OK(tree_->AssociateNode(if_op));
+  *top = if_op;
+
+  // Additionally, add a cache if required.
+  // Note that this cache op is only acting as a placeholder for the caching position
+  // within the tree. Later, a pre-pass will execute a tree transform to set up the actual
+  // caching logic in the tree.
+  if (cache_client) {
+    std::shared_ptr<CacheOp> cache_op = nullptr;
+    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op));
+    *top = cache_op;
+    *bottom = if_op;
+  }
+
   return Status::OK();
 }

@@ -1130,9 +1198,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                                     std::shared_ptr<DatasetOp> *bottom) {
   // Required arguments
   RandomDataOp::Builder builder;
+  std::shared_ptr<CacheClient> cache_client = nullptr;
+  std::shared_ptr<Sampler> sampler = nullptr;
+  int num_workers = 0;

-  if (args["num_samples"].is_none()) {
-    std::string err_msg = "Error: num_samples is a required argument";
+  if (args["total_rows"].is_none()) {
+    std::string err_msg = "Error: total_rows is a required argument";
     RETURN_STATUS_UNEXPECTED(err_msg);
   }
   std::vector<std::string> columns_to_load;
@@ -1141,16 +1212,23 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
+      } else if (key == "cache") {
+        cache_client = value.cast<std::shared_ptr<CacheClient>>();
+      } else if (key == "sampler") {
+        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
+        sampler = create().cast<std::shared_ptr<Sampler>>();
+      }
     }
   }
   if (schema_exists) {
@@ -1162,9 +1240,34 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
-  std::shared_ptr<RandomDataOp> op;
-  RETURN_IF_NOT_OK(builder.Build(&op));
-  *top = op;
+
+  // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
+  // because RandomDataOp is a non-mappable dataset that does not support sampling.
+ // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. + if (sampler) { + (void)builder.SetSampler(std::move(sampler)); + } else if (cache_client) { + int64_t num_samples = 0; + int64_t start_index = 0; + sampler = std::make_shared(num_samples, start_index); + (void)builder.SetSampler(std::move(sampler)); + } + + std::shared_ptr random_op = nullptr; + RETURN_IF_NOT_OK(builder.Build(&random_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(random_op)); + *top = random_op; + + // Add a cache op over this op if required and update the output subtree (top/bottom) + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op)); + *top = cache_op; + *bottom = random_op; + } + return Status::OK(); } @@ -1425,6 +1528,31 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr return Status::OK(); } +// Helper function to inject the cache operator over top of the current operation being built. +Status DEPipeline::AddCacheOp(std::shared_ptr cache_client, int num_workers, + std::shared_ptr input_op, std::shared_ptr *cache_op) { + std::shared_ptr new_cache_op = nullptr; + CacheOp::Builder cache_builder; + // use the same number of workers as the leaf. We need some optimization here, the user does not + // give the cache op number of workers directly. + if (num_workers != 0) { + (void)cache_builder.SetNumWorkers(num_workers); + } + (void)cache_builder.SetClient(cache_client); + RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op)); + RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op)); + // We have now created: + // + // CacheOp + // | + // input_op + // + *cache_op = new_cache_op; + + return Status::OK(); +} + // Helper function to inject a shuffle operator over top of the current operation being built. Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, std::shared_ptr *shuffle_op) { diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index 7cfc73307c..aac2d686af 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -35,6 +35,8 @@ namespace mindspore { namespace dataset { using DsOpPtr = std::shared_ptr; +class CacheClient; + // enum for the dataset operator names enum OpName { kShuffle, @@ -181,6 +183,16 @@ class DEPipeline { static Status ParsePadInfo(py::handle value, PadInfo *pad_info); + /// \brief Helper function to inject a cache operator over top of the current operation being built. + /// \param[in] cache_client The client to use for caching + /// \param[in] num_workers The number of workers to use in the cache op + /// \param[in] input_op The operator to build the cache on top of + /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be + /// the cache operator + /// \return Status return code + Status AddCacheOp(std::shared_ptr cache_client, int num_workers, std::shared_ptr input_op, + std::shared_ptr *cache_op); + /// \brief Helper function to inject a shuffle operator over top of the current operation being built. 
/// \param[in] shuffle_size The size to use in the shuffle buffer /// \param[in] input_op The operator to build shuffle on top of diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 403732d6b8..63bd5eccdc 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -35,6 +35,7 @@ #include "dataset/engine/datasetops/source/text_file_op.h" #include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/datasetops/source/voc_op.h" +#include "dataset/engine/cache/cache_client.h" #include "dataset/engine/gnn/graph.h" #include "dataset/engine/jagged_connector.h" #include "dataset/kernels/data/concatenate_op.h" @@ -768,6 +769,11 @@ void bindInfoObjects(py::module *m) { .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); } +void bindCacheClient(py::module *m) { + (void)py::class_>(*m, "CacheClient") + .def(py::init()); +} + void bindVocabObjects(py::module *m) { (void)py::class_>(*m, "Vocab") .def(py::init<>()) @@ -939,6 +945,7 @@ PYBIND11_MODULE(_c_dataengine, m) { bindSamplerOps(&m); bindDatasetOps(&m); bindInfoObjects(&m); + bindCacheClient(&m); bindVocabObjects(&m); bindGraphData(&m); bindDependIcuTokenizerOps(&m); diff --git a/mindspore/ccsrc/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/CMakeLists.txt index 66f95d0926..e3ead16d05 100644 --- a/mindspore/ccsrc/dataset/engine/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(datasetops) add_subdirectory(opt) add_subdirectory(gnn) add_subdirectory(perf) +add_subdirectory(cache) if (ENABLE_TDTQUE) add_subdirectory(tdt) endif () @@ -17,7 +18,9 @@ add_library(engine OBJECT target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) if (ENABLE_TDTQUE) - add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf) -else() - add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf) + add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf + engine-cache-client engine-cache-server) +else () + add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf + engine-cache-client engine-cache-server) endif () diff --git a/mindspore/ccsrc/dataset/engine/cache/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/cache/CMakeLists.txt new file mode 100644 index 0000000000..5e7ebea176 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/cache/CMakeLists.txt @@ -0,0 +1,8 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +add_library(engine-cache-client OBJECT + cache_client.cc + cache_request.cc) +add_library(engine-cache-server OBJECT + cache_service.cc + cache_server.cc) diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/dataset/engine/cache/cache_client.cc new file mode 100644 index 0000000000..1dc97ac43a --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/cache/cache_client.cc @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iostream>
+#include "dataset/engine/cache/cache_client.h"
+#include "dataset/engine/cache/cache_request.h"
+#include "dataset/util/bit.h"
+
+namespace mindspore {
+namespace dataset {
+
+// Constructor
+CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
+    : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
+
+// Print method for displaying cache details
+void CacheClient::Print(std::ostream &out) const {
+  out << "  Session id: " << session_id_ << "\n  Cache crc: " << cache_crc_
+      << "\n  Server cache id: " << server_connection_id_ << "\n  Cache mem size: " << cache_mem_sz_
+      << "\n  Spilling: " << std::boolalpha << spill_;
+}
+
+Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
+  CacheRowRequest rq(server_connection_id_, cookie());
+  RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
+  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
+  RETURN_IF_NOT_OK(rq.Wait());
+  if (row_id_from_server != nullptr) {
+    *row_id_from_server = rq.GetRowIdAfterCache();
+  }
+  return Status::OK();
+}
+
+Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
+  std::unique_ptr<DataBuffer> db_ptr = std::move(in);
+  auto num_rows = db_ptr->NumRows();
+  std::vector<TensorRow> all_rows;
+  if (num_rows > 0) {
+    all_rows.reserve(num_rows);
+    // Break down the DataBuffer into TensorRow. We will send the requests async
+    // and then do a final wait.
+    MemGuard<CacheRowRequest> rq_arr;
+    RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
+    CacheServer &cs = CacheServer::GetInstance();
+    for (auto i = 0; i < num_rows; ++i) {
+      TensorRow row;
+      auto rq = rq_arr[i];
+      RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
+      RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
+      RETURN_IF_NOT_OK(cs.PushRequest(rq));
+      // We can't let row go out of scope. Otherwise it will free all the tensor memory.
+      // So park it in the vector. When this function goes out of scope, its memory
+      // will be freed.
+      all_rows.push_back(std::move(row));
+    }
+    // Now we wait for the requests to be done.
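+    // Each request carries its own WaitPost, so waiting in submit order is correct: a request
+    // the server has already finished simply returns immediately from Wait().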
+ for (auto i = 0; i < num_rows; ++i) { + auto rq = rq_arr[i]; + RETURN_IF_NOT_OK(rq->Wait()); + } + } + return Status::OK(); +} + +Status CacheClient::GetRows(const std::vector &row_id, TensorTable *out) const { + RETURN_UNEXPECTED_IF_NULL(out); + BatchFetchRequest rq(server_connection_id_, row_id); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + RETURN_IF_NOT_OK(rq.RestoreRows(out)); + return Status::OK(); +} + +Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { + UniqueLock lck(&mux_); + // To create a cache, we identify ourself at the client by: + // - the shared session id + // - a crc for the tree nodes from the cache downward + // Pack these 2 into a single 64 bit request id + // + // Consider this example: + // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch + // tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch + // These are different trees in a single session, but the user wants to share the cache. + // This is not allowed because the data of these caches are different. + // + // Consider this example: + // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch + // tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch + // These are different trees in the same session, but the cached data is the same, so it is okay + // to allow the sharing of this cache between these pipelines. + + // The CRC is computed by the tree prepare phase and passed to this function when creating the cache. + // If we already have a server_connection_id_, then it means this same cache client has already been used + // to create a cache and some other tree is trying to use the same cache. + // That is allowed, however the crc better match! + if (server_connection_id_) { + if (cache_crc_ != tree_crc) { + RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!"); + } + // Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should + // skip the build phase. + lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock. + CacheClient::ServiceStat stat{}; + RETURN_IF_NOT_OK(GetStat(&stat)); + if (stat.cache_service_state == static_cast(CacheService::State::kFetchPhase)) { + return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase"); + } + } else { + cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client + // Combine the session and crc. This will form our client cache identifier. + connection_id_type connection_identification = (static_cast(session_id_) << 32) | cache_crc_; + // Now execute the cache create request using this identifier and other configs + BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone; + if (spill_) { + createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk; + } + if (generate_id) { + createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId; + } + CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + Status rc = rq.Wait(); + if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) { + server_connection_id_ = rq.GetServerConnectionId(); + if (rc.IsOk()) { + // The 1st guy creating the cache will get a cookie back. 
+ // But this object may be shared among pipelines and we don't want + // overwrite it. + cookie_ = rq.cookie(); + } + } + // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the + // CacheOp to bypass the build phase. + return rc; + } + return Status::OK(); +} + +Status CacheClient::PurgeCache() { + UniqueLock lck(&mux_); + PurgeCacheRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + return rq.Wait(); +} + +Status CacheClient::DestroyCache() { + UniqueLock lck(&mux_); + DestroyCacheRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + return rq.Wait(); +} + +Status CacheClient::GetStat(ServiceStat *stat) { + SharedLock lck(&mux_); + RETURN_UNEXPECTED_IF_NULL(stat); + GetStatRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + stat->num_disk_cached = rq.GetNumDiskCached(); + stat->num_mem_cached = rq.GetNumMemCached(); + stat->min_row_id = rq.GetMinRowId(); + stat->max_row_id = rq.GetMaxRowId(); + stat->cache_service_state = rq.GetState(); + return Status::OK(); +} + +Status CacheClient::CacheSchema(const std::unordered_map &map) { + SharedLock lck(&mux_); + CacheSchemaRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map)); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + return Status::OK(); +} + +Status CacheClient::FetchSchema(std::unordered_map *map) { + SharedLock lck(&mux_); + RETURN_UNEXPECTED_IF_NULL(map); + FetchSchemaRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + *map = rq.GetColumnMap(); + return Status::OK(); +} + +Status CacheClient::BuildPhaseDone() const { + SharedLock lck(&mux_); + BuildPhaseDoneRequest rq(server_connection_id_, cookie()); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/dataset/engine/cache/cache_client.h new file mode 100644 index 0000000000..ffdb9e9fdd --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/cache/cache_client.h @@ -0,0 +1,141 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_CACHE_CLIENT_H_ +#define DATASET_ENGINE_CACHE_CLIENT_H_ + +#include +#include +#include +#include +#include +#include + +#include "./de_tensor_generated.h" +#include "dataset/engine/data_buffer.h" +#include "dataset/engine/cache/cache_server.h" +#include "dataset/util/lock.h" + +namespace mindspore { +namespace dataset { +/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through +/// a CacheClient. 
Typical tasks include creating a cache service, caching a data buffer, restoring previously cached
+/// rows, etc.
+class CacheClient {
+ public:
+  /// \brief Constructor
+  /// \param session_id A user assigned session id for the current pipeline
+  /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
+  /// \param spill Spill to disk if out of memory
+  CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
+
+  /// \brief Destructor
+  ~CacheClient() = default;
+
+  /// \brief Getter function for returning the current session id
+  /// \return session id
+  uint64_t session_id() const { return session_id_; }
+
+  /// \brief Send a TensorRow to the cache server
+  /// \param[in] row
+  /// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
+  /// \return return code
+  Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;
+
+  /// \brief Send a DataBuffer to the cache server
+  /// \param in Unique pointer of the DataBuffer to be cached
+  /// \return return code
+  Status WriteBuffer(std::unique_ptr<DataBuffer> &&in) const;
+
+  /// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
+  /// any cache miss
+  /// \param row_id A vector of row id's
+  /// \param out A TensorTable of TensorRows.
+  /// \return return code
+  Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;
+
+  /// \brief Create a cache.
+  /// \param tree_crc A crc that was generated during tree prepare phase
+  /// \param generate_id Let the cache service generate row id
+  /// \return Status object
+  Status CreateCache(uint32_t tree_crc, bool generate_id);
+
+  /// \brief Purge a cache. Cache can be reused after reset.
+  /// \return Status object
+  Status PurgeCache();
+
+  /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
+  /// \return Status object
+  Status DestroyCache();
+
+  /// \brief Get the statistics from a cache.
+  /// \param[in/out] Pointer to a pre-allocated ServiceStat object
+  /// \return Status object
+  struct ServiceStat {
+    int64_t num_mem_cached;
+    int64_t num_disk_cached;
+    row_id_type min_row_id;
+    row_id_type max_row_id;
+    int8_t cache_service_state;
+  };
+  Status GetStat(ServiceStat *);
+
+  /// \brief Cache the schema at the cache server
+  /// \param map The unordered map of the schema
+  /// \return Status object
+  Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);
+
+  /// \brief Fetch the schema from the cache server
+  /// \param map Pointer to pre-allocated map object
+  /// \return Status object.
+  Status FetchSchema(std::unordered_map<std::string, int32_t> *map);
+
+  /// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache
+  /// client that holds the cookie is allowed to make this request
+  /// \return Status object
+  Status BuildPhaseDone() const;
+
+  /// \brief A print method typically used for debugging
+  /// \param out The output stream to write output to
+  void Print(std::ostream &out) const;
+
+  /// \brief Stream output operator overload
+  /// \return the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) {
+    cc.Print(out);
+    return out;
+  }
+
+  /// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it.
+ /// \return Cookie + std::string cookie() const { return cookie_; } + + private: + mutable RWLock mux_; + uint64_t cache_mem_sz_; + bool spill_; + // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow + // sharing of the cache. + uint32_t session_id_; + uint32_t cache_crc_; + // The server_connection_id_ is the actual id we use for operations after the cache is built + connection_id_type server_connection_id_; + // Some magic cookie returned from the cache server. + std::string cookie_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_CACHE_CLIENT_H_ diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/dataset/engine/cache/cache_request.cc new file mode 100644 index 0000000000..5485c22b6a --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/cache/cache_request.cc @@ -0,0 +1,223 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "dataset/engine/cache/cache_request.h" + +namespace mindspore { +namespace dataset { + +Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) { + buffers_.reserve(row.size() + 1); + RETURN_IF_NOT_OK(SerializeTensorRowHeader(row)); + buffers_.push_back(fbb_->GetBufferPointer()); + for (const auto &ts : row) { + buffers_.push_back(ts->GetBuffer()); + } + return Status::OK(); +} + +Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) { + try { + fbb_ = std::make_shared(); + std::vector> v; + std::vector tensor_sz; + v.reserve(row.size()); + tensor_sz.reserve(row.size()); + // We will go through each column in the row. + for (const std::shared_ptr &ts_ptr : row) { + flatbuffers::Offset ts_off; + RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off)); + v.push_back(ts_off); + tensor_sz.push_back(ts_ptr->SizeInBytes()); + } + auto column_off = fbb_->CreateVector(v); + auto data_sz_off = fbb_->CreateVector(tensor_sz); + TensorRowHeaderMsgBuilder row_builder(*fbb_); + row_builder.add_column(column_off); + row_builder.add_data_sz(data_sz_off); + // Pass the row_id even if it may not be known. + row_builder.add_row_id(row.getId()); + row_builder.add_size_of_this(-1); // fill in later after we call Finish. + auto out = row_builder.Finish(); + fbb_->Finish(out); + // Now go back to fill in size_of_this in the flat buffer. 
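+    // size_of_this was written as -1 above because the total size is only known after Finish();
+    // mutate_size_of_this() patches that fixed-width field in place without rebuilding the buffer.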
+ auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer()); + auto success = msg->mutate_size_of_this(fbb_->GetSize()); + if (!success) { + RETURN_STATUS_UNEXPECTED("Unable to set size_of_this"); + } + return Status::OK(); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } +} + +Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, + flatbuffers::Offset *out_off) { + RETURN_UNEXPECTED_IF_NULL(out_off); + const Tensor *ts = ts_ptr.get(); + auto shape_off = fbb_->CreateVector(ts->shape().AsVector()); + const auto ptr = ts->GetBuffer(); + if (ptr == nullptr) { + RETURN_STATUS_UNEXPECTED("Tensor buffer is null"); + } + auto src = ts->type().value(); + TensorType dest; +#define CASE(t) \ + case DataType::t: \ + dest = TensorType::TensorType_##t; \ + break + // Map the type to fill in the flat buffer. + switch (src) { + CASE(DE_BOOL); + CASE(DE_INT8); + CASE(DE_UINT8); + CASE(DE_INT16); + CASE(DE_UINT16); + CASE(DE_INT32); + CASE(DE_UINT32); + CASE(DE_INT64); + CASE(DE_UINT64); + CASE(DE_FLOAT16); + CASE(DE_FLOAT32); + CASE(DE_FLOAT64); + CASE(DE_STRING); + default: + MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts; + RETURN_STATUS_UNEXPECTED("Unknown type"); + } +#undef CASE + + TensorMetaMsgBuilder ts_builder(*fbb_); + ts_builder.add_dims(shape_off); + ts_builder.add_type(dest); + auto ts_off = ts_builder.Finish(); + *out_off = ts_off; + return Status::OK(); +} + +Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, + std::shared_ptr *out) { + RETURN_UNEXPECTED_IF_NULL(col_ts); + auto shape_in = col_ts->dims(); + auto type_in = col_ts->type(); + std::vector v; + v.reserve(shape_in->size()); + v.assign(shape_in->begin(), shape_in->end()); + TensorShape shape(v); + DataType::Type dest = DataType::DE_UNKNOWN; +#define CASE(t) \ + case TensorType_##t: \ + dest = DataType::Type::t; \ + break + + switch (type_in) { + CASE(DE_BOOL); + CASE(DE_INT8); + CASE(DE_UINT8); + CASE(DE_INT16); + CASE(DE_UINT16); + CASE(DE_INT32); + CASE(DE_UINT32); + CASE(DE_INT64); + CASE(DE_UINT64); + CASE(DE_FLOAT16); + CASE(DE_FLOAT32); + CASE(DE_FLOAT64); + CASE(DE_STRING); + } +#undef CASE + + DataType type(dest); + std::shared_ptr ts = + std::make_shared(shape, type, static_cast(data.GetPointer()), data.GetSize()); + // Next we restore the real data which can be embedded or stored separately. + if (ts->SizeInBytes() != data.GetSize()) { + MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" + << "Dumping tensor\n" + << *ts << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. 
See log file for details."); + } + *out = std::move(ts); + return Status::OK(); +} + +Status BatchFetchRequest::RestoreRows(TensorTable *out) { + RETURN_UNEXPECTED_IF_NULL(out); + auto num_elements = row_id_.size(); + auto *offset_array = reinterpret_cast(mem_.GetPointer()); + TensorTable tbl; + tbl.reserve(num_elements); + ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes()); + for (auto i = 0; i < num_elements; ++i) { + auto len = offset_array[i + 1] - offset_array[i]; + TensorRow row; + row.setId(row_id_.at(i)); + if (len > 0) { + ReadableSlice row_data(all, offset_array[i], len); + // Next we de-serialize flat buffer to get back each column + auto msg = GetTensorRowHeaderMsg(row_data.GetPointer()); + auto msg_sz = msg->size_of_this(); + // Start of the tensor data + auto ts_offset = msg_sz; + row.reserve(msg->column()->size()); + for (auto k = 0; k < msg->column()->size(); ++k) { + auto col_ts = msg->column()->Get(k); + std::shared_ptr ts; + ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k)); + RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts)); + row.push_back(ts); + ts_offset += data.GetSize(); + } + } + tbl.push_back(std::move(row)); + } + *out = std::move(tbl); + return Status::OK(); +} + +Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map &map) { + try { + fbb_ = std::make_shared(); + std::vector> v; + v.reserve(map.size()); + for (auto &column : map) { + auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second); + v.push_back(c); + } + auto v_off = fbb_->CreateVector(v); + auto final_off = CreateSchemaMsg(*fbb_, v_off); + fbb_->Finish(final_off); + buf_ = fbb_->GetBufferPointer(); + len_of_buf_ = fbb_->GetSize(); + return Status::OK(); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } +} + +std::unordered_map FetchSchemaRequest::GetColumnMap() { + if (column_name_id_map_.empty()) { + auto *map_msg = flatbuffers::GetRoot(mem_.GetPointer()); + auto v = map_msg->column(); + for (auto i = 0; i < v->size(); ++i) { + auto col = map_msg->column()->Get(i); + column_name_id_map_.emplace(col->name()->str(), col->id()); + } + } + return column_name_id_map_; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/dataset/engine/cache/cache_request.h new file mode 100644 index 0000000000..3182816e54 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/cache/cache_request.h @@ -0,0 +1,225 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +#ifndef DATASET_ENGINE_CACHE_REQ_H_ +#define DATASET_ENGINE_CACHE_REQ_H_ + +#include +#include +#include +#include +#include +#include + +#include "./de_tensor_generated.h" +#include "dataset/core/tensor_row.h" +#include "dataset/util/slice.h" +#include "dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +/// \brief CacheClient communicates with CacheServer using Requests. +class BaseRequest { + public: + // Request types + enum class RequestType : int16_t { + kCacheRow = 0, + kBatchFetchRows = 1, + kCreateCache = 2, + kPurgeCache = 3, + kDestroyCache = 4, + kGetStat = 5, + kCacheSchema = 6, + kFetchSchema = 7, + kBuildPhaseDone = 8, + // Add new request before it. + kRequestUnknown = 32767 + }; + // For kCreateCache + enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; + friend class CacheServer; + /// \brief Base class of a cache server request + /// \param connection_id A combination of session id and crc that uniquely identifies a connection. + /// \param type Type of the request + explicit BaseRequest(connection_id_type connection_id, RequestType type) + : type_(type), connection_id_(connection_id) {} + virtual ~BaseRequest() = default; + /// \brief Wait for the completion of a request + /// \return Status returned from the cache server + Status Wait() { + RETURN_IF_NOT_OK(wp_.Wait()); + return rc_; + } + + /// \brief Getter function of the current connection id + /// \return Connection id + connection_id_type GetServerConnectionId() const { return connection_id_; } + + private: + RequestType type_; + connection_id_type connection_id_; + Status rc_; + WaitPost wp_; +}; +/// \brief Request to cache a single TensorRow +class CacheRowRequest : public BaseRequest { + public: + friend class CacheServer; + explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie) + : BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {} + ~CacheRowRequest() = default; + + /// \brief Serialize a TensorRow for streaming to the cache server + /// \param row TensorRow + /// \return Status object + Status SerializeCacheRowRequest(const TensorRow &row); + /// \brief Return the row id assigned to this row for non-mappable dataset + /// \return row id of the cached row + row_id_type GetRowIdAfterCache() { return row_id_from_server_; } + + private: + std::shared_ptr fbb_; + row_id_type row_id_from_server_; + std::vector buffers_; + std::string cookie_; + + /// \brief Private function to serialize one TensorRow + /// \param row TensorRow + /// \return Status object + Status SerializeTensorRowHeader(const TensorRow &row); + /// \brief Private function to serialize one Tensor + /// \param ts_ptr Tensor + /// \return Status object + Status SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, flatbuffers::Offset *out_off); +}; +/// \brief Request to fetch rows in batch +class BatchFetchRequest : public BaseRequest { + public: + friend class CacheServer; + friend class CacheService; + BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id) + : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {} + Status RestoreRows(TensorTable *out); + + private: + std::vector row_id_; + MemGuard mem_; + Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr *out); +}; +/// \brief Request to create a cache for the current connection +class CreationCacheRequest : public BaseRequest { + public: + friend class 
CacheServer; + /// \brief Constructor + /// \param connection_id + /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited + /// \param flag Attributes of the cache. + explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz, + CreateCacheFlag flag = CreateCacheFlag::kNone) + : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {} + + std::string cookie() const { return cookie_; } + + private: + uint64_t cache_mem_sz; + CreateCacheFlag flag_; + std::string cookie_; +}; +/// \brief Request to purge a cache. +class PurgeCacheRequest : public BaseRequest { + public: + friend class CacheServer; + explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {} +}; +/// \brief Request to destroy a cache +class DestroyCacheRequest : public BaseRequest { + public: + friend class CacheServer; + explicit DestroyCacheRequest(connection_id_type connection_id) + : BaseRequest(connection_id, RequestType::kDestroyCache) {} +}; +/// \brief Obtain the statistics of the current connection +class GetStatRequest : public BaseRequest { + public: + friend class CacheServer; + friend class CacheService; + explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {} + row_id_type GetMinRowId() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->min_row_id(); + } + row_id_type GetMaxRowId() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->max_row_id(); + } + int64_t GetNumMemCached() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->num_mem_cached(); + } + int64_t GetNumDiskCached() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->num_disk_cached(); + } + uint8_t GetState() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->state(); + } + + private: + MemGuard mem_; +}; +/// \brief Request to cache a schema +class CacheSchemaRequest : public BaseRequest { + public: + friend class CacheServer; + explicit CacheSchemaRequest(connection_id_type connection_id) + : BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {} + ~CacheSchemaRequest() = default; + + Status SerializeCacheSchemaRequest(const std::unordered_map &map); + const void *GetBuffer() const { return buf_; } + + private: + std::shared_ptr fbb_; + const void *buf_; + int64_t len_of_buf_; +}; +/// \brief Request to fetch a schema +class FetchSchemaRequest : public BaseRequest { + public: + friend class CacheServer; + explicit FetchSchemaRequest(connection_id_type connection_id) + : BaseRequest(connection_id, RequestType::kFetchSchema) {} + ~FetchSchemaRequest() = default; + + std::unordered_map GetColumnMap(); + + private: + MemGuard mem_; + std::unordered_map column_name_id_map_; +}; +/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. 
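+/// Only the cache client holding the creation cookie is allowed to make this request.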
+class BuildPhaseDoneRequest : public BaseRequest {
+ public:
+  friend class CacheServer;
+  BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
+      : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {}
+
+ private:
+  std::string cookie_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_CACHE_REQ_H_
diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/dataset/engine/cache/cache_server.cc
new file mode 100644
index 0000000000..88d617b598
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/cache/cache_server.cc
@@ -0,0 +1,252 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#include "dataset/engine/cache/cache_server.h"
+#include "dataset/engine/cache/cache_service.h"
+#include "dataset/engine/cache/cache_request.h"
+#include "dataset/util/bit.h"
+
+namespace mindspore {
+namespace dataset {
+Status CacheServer::DoServiceStart() {
+  if (!top_.empty()) {
+    Path spill(top_);
+    RETURN_IF_NOT_OK(spill.CreateDirectories());
+    MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
+  }
+  RETURN_IF_NOT_OK(vg_.ServiceStart());
+  cache_q_ = std::make_shared<Queue<BaseRequest *>>(1024);
+  RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
+  auto f = std::bind(&CacheServer::ServerRequest, this);
+  // Spawn a few threads to serve the requests.
+  for (auto i = 0; i < num_workers_; ++i) {
+    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f));
+  }
+  return Status::OK();
+}
+
+Status CacheServer::DoServiceStop() {
+  Status rc;
+  Status rc2;
+  // First stop all the threads.
+  RETURN_IF_NOT_OK(vg_.ServiceStop());
+  // Clean up all the caches if any.
+  UniqueLock lck(&rwLock_);
+  auto it = all_caches_.begin();
+  while (it != all_caches_.end()) {
+    auto cs = std::move(it->second);
+    rc2 = cs->ServiceStop();
+    if (rc2.IsError()) {
+      rc = rc2;
+    }
+    ++it;
+  }
+  return rc;
+}
+
+CacheService *CacheServer::GetService(connection_id_type id) const {
+  SharedLock lck(&rwLock_);
+  auto it = all_caches_.find(id);
+  if (it != all_caches_.end()) {
+    return it->second.get();
+  }
+  return nullptr;
+}
+
+Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz,
+                                  BaseRequest::CreateCacheFlag flag, std::string *out_cookie) {
+  // We can't do spilling unless this server is setup with a spill path in the first place
+  bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk;
+  bool generate_id =
+    (flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId;
+  if (spill && top_.empty()) {
+    RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
+  }
+  RETURN_UNEXPECTED_IF_NULL(out_cookie);
+  *out_cookie = "";
+  // Before creating the cache, first check if this is a request for a shared usage of an existing cache
+  // If two CreateService come in with identical connection_id, we need to serialize the create.
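+  // The exclusive lock below provides that serialization: concurrent creators of the same
+  // connection_id are ordered, and only the winner actually constructs the CacheService.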
+  // The first create will be successful and be given a special cookie.
+  UniqueLock lck(&rwLock_);
+  auto end = all_caches_.end();
+  auto it = all_caches_.find(connection_id);
+  if (it == end) {
+    std::unique_ptr<CacheService> cs;
+    try {
+      cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
+      RETURN_IF_NOT_OK(cs->ServiceStart());
+      *out_cookie = cs->cookie();
+      all_caches_.emplace(connection_id, std::move(cs));
+    } catch (const std::bad_alloc &e) {
+      return Status(StatusCode::kOutOfMemory);
+    }
+  } else {
+    MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
+    // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
+    // or treat it as OK.
+    return Status(StatusCode::kDuplicateKey);
+  }
+  return Status::OK();
+}
+
+/// This is the main loop the cache server thread(s) are running.
+/// Each thread will pop a request and save the result in the same request.
+/// The sender will wait on the wait post in the request. Once the request
+/// is fulfilled, the server thread will do a post signalling the request
+/// is processed.
+/// \return
+Status CacheServer::ServerRequest() {
+  TaskManager::FindMe()->Post();
+  // Loop forever until we are interrupted.
+  while (true) {
+    BaseRequest *base_rq = nullptr;
+    RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq));
+    auto cs = GetService(base_rq->connection_id_);
+    // Except for creating a new session, we expect cs is not null.
+    switch (base_rq->type_) {
+      case BaseRequest::RequestType::kCacheRow: {
+        if (cs == nullptr) {
+          std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
+          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+        } else {
+          auto *rq = reinterpret_cast<CacheRowRequest *>(base_rq);
+          // Only if the cookie matches can we accept an insert into a cache that has a build phase
+          if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) {
+            rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_);
+          } else {
+            return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
+          }
+        }
+        break;
+      }
+      case BaseRequest::RequestType::kBatchFetchRows: {
+        if (cs == nullptr) {
+          std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
+          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+        } else {
+          auto *rq = reinterpret_cast<BatchFetchRequest *>(base_rq);
+          rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_);
+        }
+        break;
+      }
+      case BaseRequest::RequestType::kCreateCache: {
+        // If the cache is already created we still need to run the creation so that we do sanity checks on the
+        // client id and return the cache id back to the user.
+        auto *rq = reinterpret_cast<CreationCacheRequest *>(base_rq);
+        rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_);
+        break;
+      }
+      case BaseRequest::RequestType::kPurgeCache: {
+        if (cs != nullptr) {
+          base_rq->rc_ = cs->Purge();
+        } else {
+          // it is already purged. Ignore it.
+          base_rq->rc_ = Status::OK();
+        }
+        break;
+      }
+      case BaseRequest::RequestType::kDestroyCache: {
+        if (cs != nullptr) {
+          // We need a strong lock to protect the map.
+          connection_id_type id = base_rq->connection_id_;
+          UniqueLock lck(&rwLock_);
+          // std::map will invoke the destructor of CacheService. So we don't need to do anything here.
+          auto n = all_caches_.erase(id);
+          if (n == 0) {
+            // It has been destroyed by another duplicate request.
+            MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to destroy cache service";
+          }
+          base_rq->rc_ = Status::OK();
+        } else {
+          // it is already destroyed. Ignore it.
+          base_rq->rc_ = Status::OK();
+        }
+        break;
+      }
+      case BaseRequest::RequestType::kGetStat: {
+        if (cs == nullptr) {
+          std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
+          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+        } else {
+          auto *rq = reinterpret_cast<GetStatRequest *>(base_rq);
+          CacheService::ServiceStat svc_stat;
+          rq->rc_ = cs->GetStat(&svc_stat);
+          if (rq->rc_.IsOk()) {
+            flatbuffers::FlatBufferBuilder fbb;
+            ServiceStatMsgBuilder bld(fbb);
+            bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
+            bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
+            bld.add_max_row_id(svc_stat.max_);
+            bld.add_min_row_id(svc_stat.min_);
+            bld.add_state(svc_stat.state_);
+            auto offset = bld.Finish();
+            fbb.Finish(offset);
+            rq->rc_ = rq->mem_.allocate(fbb.GetSize());
+            if (rq->rc_.IsOk()) {
+              WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize());
+              ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize());
+              RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
+            }
+          }
+        }
+        break;
+      }
+      case BaseRequest::RequestType::kCacheSchema: {
+        if (cs == nullptr) {
+          std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
+          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+        } else {
+          auto *rq = reinterpret_cast<CacheSchemaRequest *>(base_rq);
+          rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_);
+        }
+        break;
+      }
+      case BaseRequest::RequestType::kFetchSchema: {
+        if (cs == nullptr) {
+          std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
+          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+        } else {
+          auto *rq = reinterpret_cast<FetchSchemaRequest *>(base_rq);
+          rq->rc_ = cs->FetchSchema(&rq->mem_);
+        }
+        break;
+      }
+      case BaseRequest::RequestType::kBuildPhaseDone: {
+        if (cs == nullptr) {
+          std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
+          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+        } else {
+          auto *rq = reinterpret_cast<BuildPhaseDoneRequest *>(base_rq);
+          // We can only allow the phase switch if the cookie matches.
+          if (rq->cookie_ == cs->cookie()) {
+            rq->rc_ = cs->BuildPhaseDone();
+          } else {
+            return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
+          }
+        }
+        break;
+      }
+      default:
+        base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type");
+    }
+    // Notify it is done, and move on to the next request.
+    base_rq->wp_.Set();
+  }
+  return Status::OK();
+}
+CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers)
+    : top_(spill_path), num_workers_(num_workers) {}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/dataset/engine/cache/cache_server.h
new file mode 100644
index 0000000000..f83fa1cb6d
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/cache/cache_server.h
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef DATASET_ENGINE_CACHE_SERVER_H_ +#define DATASET_ENGINE_CACHE_SERVER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "dataset/engine/cache/cache_service.h" +#include "dataset/core/tensor.h" +#include "dataset/util/arena.h" +#include "dataset/util/cache_pool.h" +#include "dataset/util/lock.h" +#include "dataset/util/service.h" +#include "dataset/util/services.h" +#include "dataset/util/system_pool.h" +#include "dataset/util/queue.h" +#include "dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +class BaseRequest; +/// \brief A server which provides CacheService services. +class CacheServer : public Service { + public: + friend class Services; + using cache_index = std::map>; + + CacheServer(const CacheServer &) = delete; + CacheServer &operator=(const CacheServer &) = delete; + CacheServer(CacheServer &&) = delete; + CacheServer &operator=(CacheServer &) = delete; + static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); } + Status DoServiceStart() override; + Status DoServiceStop() override; + ~CacheServer() { (void)ServiceStop(); } + + /// \brief For the current demonstration, a cache client contacts cache server using a Queue. + /// \param rq + /// \return Status object + Status PushRequest(BaseRequest *rq) { + RETURN_UNEXPECTED_IF_NULL(rq); + RETURN_IF_NOT_OK(cache_q_->Add(rq)); + return Status::OK(); + } + + private: + mutable RWLock rwLock_; + std::string top_; + cache_index all_caches_; + std::shared_ptr> cache_q_; + TaskGroup vg_; + int32_t num_workers_; + + /// \brief Constructor + /// \param spill_path Top directory for spilling buffers to. + /// \param num_workers Number of threads for handling requests. + explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3); + + /// \brief Locate a cache service from connection id. + /// \return Pointer to cache service. Null if not found + CacheService *GetService(connection_id_type id) const; + + /// \brief Create a cache service. We allow multiple clients to create the same cache service. + /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given + /// a special unique cookie. + /// \param[in] connection_id This is from a Cache client. + /// \param[in] cache_mem_sz + /// \param[in] flag + /// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator + /// \return Status object + Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag, + std::string *out_cookie); + + /// \brief Entry point for all server threads. 
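+  /// Each worker pops a BaseRequest from cache_q_, dispatches on its type, stores the result
+  /// status in the request, and posts its WaitPost so the waiting client wakes up.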
+  Status ServerRequest();
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_CACHE_SERVER_H_
diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/dataset/engine/cache/cache_service.cc
new file mode 100644
index 0000000000..1cbe3fdb4e
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/cache/cache_service.cc
@@ -0,0 +1,265 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#include "dataset/engine/cache/cache_service.h"
+#include "dataset/util/slice.h"
+
+namespace mindspore {
+namespace dataset {
+CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id)
+    : root_(root),
+      cache_mem_sz_(mem_sz),
+      cp_(nullptr),
+      map_(nullptr),
+      next_id_(0),
+      generate_id_(generate_id),
+      schema_key_(-1),
+      st_(generate_id ? State::kBuildPhase : State::kNone) {}
+CacheService::~CacheService() { (void)ServiceStop(); }
+bool CacheService::UseArena() {
+  // If a fixed size was given, use an Arena instead of the pool from the global context.
+  return (cache_mem_sz_ > 0);
+}
+Status CacheService::DoServiceStart() {
+  std::shared_ptr<MemoryPool> mp;
+  if (UseArena()) {
+    // Create a fixed size arena based on the parameter.
+    std::shared_ptr<Arena> arena;
+    RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_));
+    mp = std::move(arena);
+  } else {
+    // Unlimited size. Simply use a system pool. Another choice is CircularPool.
+    mp = std::make_shared<SystemPool>();
+  }
+  // Put together a CachePool for backing up the Tensor.
+  cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp), root_);
+  RETURN_IF_NOT_OK(cp_->ServiceStart());
+  // Set up the B+ tree as well, but use the system pool instead.
+  map_ = std::make_shared<row_map>();
+  // Assign a name to this cache. It is used for exclusive connection, but we can just reuse CachePool's name.
+  cookie_ = cp_->MyName();
+  return Status::OK();
+}
+Status CacheService::DoServiceStop() {
+  if (cp_ != nullptr) {
+    RETURN_IF_NOT_OK(cp_->ServiceStop());
+  }
+  return Status::OK();
+}
+Status CacheService::CacheRow(const std::vector &buf, row_id_type *row_id_generated) {
+  SharedLock rw(&rw_lock_);
+  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
+  if (st_ == State::kFetchPhase) {
+    // For this kind of cache service, once the build phase has switched into the fetch phase,
+    // we can't allow others to cache more rows.
+    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
+  }
+  try {
+    // The first buffer is a flatbuffer which describes the buffers that follow.
+    auto fb = buf.front();
+    RETURN_UNEXPECTED_IF_NULL(fb);
+    auto msg = GetTensorRowHeaderMsg(fb);
+    // If the server side is designed to ignore the incoming row id, we generate one.
+    if (generate_id_) {
+      *row_id_generated = GetNextRowId();
+      // Some debug information on how many rows we have generated so far.
+ if ((*row_id_generated) % 1000 == 0) { + MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated; + } + } else { + if (msg->row_id() < 0) { + std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id()); + RETURN_STATUS_UNEXPECTED(errMsg); + } + *row_id_generated = msg->row_id(); + } + auto size_of_this = msg->size_of_this(); + auto column_hdr = msg->column(); + // Number of tensor buffer should match the number of columns plus one. + if (buf.size() != column_hdr->size() + 1) { + std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) + + " but get " + std::to_string(buf.size()); + RETURN_STATUS_UNEXPECTED(errMsg); + } + // Next we store in either memory or on disk. Low level code will consolidate everything in one piece. + std::vector all_data; + all_data.reserve(column_hdr->size() + 1); + all_data.emplace_back(fb, size_of_this); + for (auto i = 0; i < column_hdr->size(); ++i) { + all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); + } + // Now we cache the flat buffer. + CachePool::key_type key; + RETURN_IF_NOT_OK(cp_->Insert(all_data, &key)); + Status rc = map_->DoInsert(*row_id_generated, key); + if (rc == Status(StatusCode::kDuplicateKey)) { + MS_LOG(DEBUG) << "Ignoring duplicate key"; + } else { + RETURN_IF_NOT_OK(rc); + } + return Status::OK(); + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } +} +std::ostream &operator<<(std::ostream &out, const CacheService &cs) { + // Then show any custom derived-internal stuff + out << "\nCache memory size: " << cs.cache_mem_sz_; + out << "\nSpill path: "; + if (cs.root_.empty()) { + out << "None"; + } else { + out << cs.GetSpillPath(); + } + return out; +} +Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); } +Status CacheService::Purge() { + // First we must lock exclusively. No one else can cache/restore anything. + UniqueLock rw(&rw_lock_); + RETURN_IF_NOT_OK(cp_->ServiceStop()); + auto new_map = std::make_shared(); + map_.reset(); + map_ = std::move(new_map); + next_id_ = 0; + RETURN_IF_NOT_OK(cp_->ServiceStart()); + return Status::OK(); +} +Status CacheService::GetStat(CacheService::ServiceStat *out) { + SharedLock rw(&rw_lock_); + RETURN_UNEXPECTED_IF_NULL(out); + if (st_ == State::kNone || st_ == State::kFetchPhase) { + out->stat_ = cp_->GetStat(); + out->state_ = static_cast(st_); + auto it = map_->begin(); + if (it != map_->end()) { + out->min_ = it.key(); + auto end_it = map_->end(); + --end_it; + out->max_ = end_it.key(); + } + } else { + out->state_ = static_cast(st_); + } + return Status::OK(); +} +Status CacheService::BatchFetch(const std::vector &v, MemGuard *out) const { + RETURN_UNEXPECTED_IF_NULL(out); + SharedLock rw(&rw_lock_); + if (st_ == State::kBuildPhase) { + // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. 
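The check that follows rejects fetches while the service is still being built. Once in the fetch phase, BatchFetch packs its reply as an offset table followed by the row payloads. A minimal client-side decoding sketch, assuming a fetch of rows {10, 11, 12} where row 11 missed the cache (the MemGuard accessor names are taken from this file; the rest is illustrative only):

    // Layout: [off_0 off_1 off_2 off_3 | row(10) bytes | row(12) bytes]
    // off_0 == 4 * sizeof(int64_t); off_{i+1} - off_i == 0 marks a cache miss.
    auto *base = reinterpret_cast<char *>(mem.GetMutablePointer());
    auto *offsets = reinterpret_cast<int64_t *>(base);
    for (size_t i = 0; i < 3; ++i) {
      int64_t sz = offsets[i + 1] - offsets[i];
      if (sz == 0) continue;  // cache miss: an empty row for the merge stage to fill
      // Each row starts with its TensorRowHeaderMsg flatbuffer, then the tensor data.
      ReadableSlice row(base + offsets[i], sz);
    }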
+ RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + const auto num_elements = v.size(); + int64_t mem_sz = (num_elements + 1) * sizeof(int64_t); + int64_t data_offset = mem_sz; + std::vector sz_v; + std::vector keys; + sz_v.reserve(num_elements); + keys.reserve(num_elements); + for (auto row_id : v) { + auto r = map_->Search(row_id); + if (r.second) { + auto &it = r.first; + CachePool::key_type key = it.value(); + auto sz = cp_->GetSize(key); + if (sz == 0) { + std::string errMsg = "Key not found: "; + errMsg += std::to_string(key); + RETURN_STATUS_UNEXPECTED(errMsg); + } + keys.push_back(key); + sz_v.push_back(sz); + mem_sz += sz; + } else { + keys.push_back(-1); + sz_v.push_back(0); + } + } + MemGuard mem; + RETURN_IF_NOT_OK(mem.allocate(mem_sz)); + auto *offset_array = reinterpret_cast(mem.GetMutablePointer()); + offset_array[0] = data_offset; + WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes()); + for (auto i = 0; i < num_elements; ++i) { + auto sz = sz_v.at(i); + offset_array[i + 1] = offset_array[i] + sz; + if (sz > 0) { + WritableSlice row_data(all, offset_array[i], sz); + auto key = keys.at(i); + size_t bytesRead = 0; + RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead)); + if (bytesRead != sz) { + MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "." + << " Internal key: " << key << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); + } + } + } + *out = std::move(mem); + return Status::OK(); +} +Status CacheService::CacheSchema(const void *buf, int64_t len) { + SharedLock rw(&rw_lock_); + if (st_ == State::kFetchPhase) { + // For this kind of cache service, once we are done with the build phase into fetch phase, we can't + // allow other to cache more rows. + RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + // This is a special request and we need to remember where we store it. + // In case we are calling the same function from multiple threads, only + // the first one is considered. Rest is ignored. + CachePool::key_type cur_key = schema_key_; + CachePool::key_type key; + if (cur_key < 0) { + RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key)); + auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key); + MS_LOG(DEBUG) << "Caching Schema. Result = " << result; + } else { + MS_LOG(DEBUG) << "Caching Schema already done"; + } + return Status::OK(); +} +Status CacheService::FetchSchema(MemGuard *out) const { + SharedLock rw(&rw_lock_); + if (st_ == State::kBuildPhase) { + // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. 
+ RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + RETURN_UNEXPECTED_IF_NULL(out); + MemGuard mem; + if (schema_key_ >= 0) { + auto len = cp_->GetSize(schema_key_); + RETURN_IF_NOT_OK(mem.allocate(len)); + auto slice = WritableSlice(mem.GetMutablePointer(), len); + RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); + *out = std::move(mem); + } else { + return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); + } + return Status::OK(); +} +Status CacheService::BuildPhaseDone() { + if (HasBuildPhase()) { + // Exclusive lock to switch phase + UniqueLock rw(&rw_lock_); + st_ = State::kFetchPhase; + return Status::OK(); + } else { + RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/dataset/engine/cache/cache_service.h new file mode 100644 index 0000000000..60cfa40a50 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/cache/cache_service.h @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef DATASET_ENGINE_CACHE_SERVICE_H_ +#define DATASET_ENGINE_CACHE_SERVICE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "./de_tensor_generated.h" +#include "dataset/core/global_context.h" +#include "dataset/core/tensor.h" +#include "dataset/engine/cache/cache_request.h" +#include "dataset/util/arena.h" +#include "dataset/util/btree.h" +#include "dataset/util/cache_pool.h" +#include "dataset/util/service.h" +#include "dataset/util/services.h" +#include "dataset/util/system_pool.h" + +namespace mindspore { +namespace dataset { +struct CacheStat; +/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is +/// created to support spilling +class CacheService : public Service { + public: + friend class CacheServer; + using row_map = BPlusTree; + + enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase }; + + /// \brief Constructor + /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited + /// \param root Spill path. Empty string means no spilling + /// \param generate_id If the cache service should generate row id for buffer that is cached. + /// For non-mappable dataset, this should be set to true. + CacheService(uint64_t mem_sz, const std::string &root, bool generate_id); + ~CacheService(); + + /// \brief For fixed size memory, we will create an Arena. + /// \return false if unlimited memory. + bool UseArena(); + + Status DoServiceStart() override; + Status DoServiceStop() override; + + /// \brief Main function to cache a row which is in form a series of buffers. + /// The first buffer is a Google flatbuffer which describes the rest of the buffers followed. 
+ /// \param[in] buf Vector of buffer + /// \param[out] row_id_generated The row id assigned to this row if any + /// \return Status object + Status CacheRow(const std::vector &buf, row_id_type *row_id_generated); + /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded + /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. + /// \param[in] v A vector of row id. + /// \param[out] out A contiguous memory buffer that holds the requested rows. + /// \return Status object + Status BatchFetch(const std::vector &v, MemGuard *out) const; + + /// \brief Getter function + /// \return Spilling path + Path GetSpillPath() const; + /// \brief A structure returned from the cache server for statistics request. + class ServiceStat { + public: + using state_type = std::underlying_type::type; + ServiceStat() : min_(0), max_(0), state_(0) {} + CachePool::CacheStat stat_{}; + row_id_type min_; + row_id_type max_; + state_type state_; + }; + /// \brief Statistics for the current service + /// \param[in/out] A pointer to a pre-allocated ServiceStat structure + /// \return Status Object + Status GetStat(ServiceStat *); + /// \brief Cache schema + /// \param buf A Google Flatbuffer that contains the schema + /// \param len size of the buffer + /// \return Status object + Status CacheSchema(const void *buf, int64_t len); + /// \brief Fetch schema + /// \param out A contiguous memory that contains the serialized form of schema. + /// \return Status object + Status FetchSchema(MemGuard *out) const; + /// \brief Purge the content of a cache + /// \return Status object + Status Purge(); + /// \brief Overload the << operator to print a cache service + /// \param out std::ostream + /// \param cs A cache service + /// \return std::ostream + friend std::ostream &operator<<(std::ostream &out, const CacheService &cs); + /// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient + /// is the creator + /// \return Cookie + std::string cookie() const { return cookie_; } + /// \brief If this cache service generates row id for buffer cached, it is divided into two phases, a build phase and + /// a read phase. + /// \return True if has two phases. + bool HasBuildPhase() const { return generate_id_; } + /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. + /// \return Status object + Status BuildPhaseDone(); + + private: + mutable RWLock rw_lock_; + std::string root_; + uint64_t cache_mem_sz_; + std::shared_ptr cp_; + std::shared_ptr map_; + std::atomic next_id_; + bool generate_id_; + std::atomic schema_key_; + std::string cookie_; + State st_; + + /// \brief Private function to generate a row id + /// \return Row id assigned. + row_id_type GetNextRowId() { return next_id_.fetch_add(1); } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_CACHE_SERVICE_H_ diff --git a/mindspore/ccsrc/dataset/engine/cache/de_tensor.fbs b/mindspore/ccsrc/dataset/engine/cache/de_tensor.fbs new file mode 100644 index 0000000000..de26069f23 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/cache/de_tensor.fbs @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +namespace mindspore.dataset; + +/// Type of a Tensor +enum TensorType : byte { + DE_UNKNOWN = 0, + DE_BOOL = 1, + DE_INT8 = 2, + DE_UINT8 = 3, + DE_INT16 = 4, + DE_UINT16 = 5, + DE_INT32 = 6, + DE_UINT32 = 7, + DE_INT64 = 8, + DE_UINT64 = 9, + DE_FLOAT16 = 10, + DE_FLOAT32 = 11, + DE_FLOAT64 = 12, + DE_STRING = 13 +} + +/// The meta information of a Tensor +/// \note Only the type and shape are considered meta information. Tensor data is excluded. +table TensorMetaMsg { + dims:[int64] (required); + type:TensorType; +} + +/// This is the first buffer that is sent to a Cache server when a TensorRow is serialized. +/// \param row_id is the row id of the TensorRow. +/// \param column The meta information of each Tensor in the row +/// \param size of this serialized buffer +/// \param size of each tensor data buffer that follows +table TensorRowHeaderMsg { + row_id:int64; + column:[TensorMetaMsg] (required); + size_of_this:int64; + data_sz:[int64] (required); +} + +root_type TensorRowHeaderMsg; + +/// A row of row id's +table TensorRowIds { + row_id:[int64] (required); +} + +/// Statistics returned from each cache service +/// \note It must match CacheService::ServiceStat +table ServiceStatMsg { + num_mem_cached:int64; + num_disk_cached:int64; + min_row_id:int64; + max_row_id:int64; + state:int8; +} + +/// Column description of each column in a schema +table ColumnNameMsg { + name:string; + id:int32; +} + +/// Serialized form of a schema +table SchemaMsg { + column:[ColumnNameMsg]; +} diff --git a/mindspore/ccsrc/dataset/engine/data_buffer.cc b/mindspore/ccsrc/dataset/engine/data_buffer.cc index 32a70c259f..718721b906 100644 --- a/mindspore/ccsrc/dataset/engine/data_buffer.cc +++ b/mindspore/ccsrc/dataset/engine/data_buffer.cc @@ -24,10 +24,8 @@ namespace dataset { // Description: This is the main constructor that is used for making a buffer DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {} -// Name: print() -// Description: A function that prints info about the DataBuffer (base class version) -void DataBuffer::Print(std::ostream &out, // In: The output stream to print to - bool show_all) const { // In: T/F if it should show everything +// A method for debug printing of the buffer +void DataBuffer::Print(std::ostream &out, bool show_all) const { out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n"; // If the column counts are set then it means that data has been set into @@ -46,11 +44,6 @@ void DataBuffer::Print(std::ostream &out, // In: The output stream to print } } -Status DataBuffer::Load() { - std::string err_msg = "Base class load called, but it does not have an implementation!"; - RETURN_STATUS_UNEXPECTED(err_msg); -} - // Remove me!! 
Callers should fetch rows via pop Status DataBuffer::GetTensor(std::shared_ptr *ptr, int32_t row_id, int32_t col_id) const { if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) { @@ -92,8 +85,5 @@ Status DataBuffer::SliceOff(int64_t number_of_rows) { return Status::OK(); } - -// Destructor -DataBuffer::~DataBuffer() {} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/data_buffer.h b/mindspore/ccsrc/dataset/engine/data_buffer.h index 2ab0783519..b539bdaf7b 100644 --- a/mindspore/ccsrc/dataset/engine/data_buffer.h +++ b/mindspore/ccsrc/dataset/engine/data_buffer.h @@ -29,11 +29,9 @@ namespace mindspore { namespace dataset { -// The DataBuffer class is a base class that will represent the data for n values based -// on a unique row id for each row of data. -// There can be different types of DataBuffers to abstract over how the data is stored -// in memory and acquired from storage. -// Each buffer holds a range of consecutive row id's. +/// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between +/// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format +/// where n TensorRows may consist of m tensors (columns). class DataBuffer { public: // Buffer flags @@ -47,13 +45,13 @@ class DataBuffer { // Description: This is the main constructor that is used for making a buffer DataBuffer(int32_t id, BufferFlags flags); - // Destructor - virtual ~DataBuffer(); + /// \brief default destructor + ~DataBuffer() = default; - // Name: print() - // Description: A function that prints info about the DataBuffer (base class version) - virtual void Print(std::ostream &out, // In: The output stream to print to - bool show_all) const; // In: T/F if it should show everything + /// \brief A method for debug printing of the buffer + /// \param[inout] out The stream to write to + /// \param[in] show_all A boolean to toggle between details and summary printing + void Print(std::ostream &out, bool show_all) const; // Provide stream operator for displaying it friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) { @@ -61,10 +59,6 @@ class DataBuffer { return out; } - // Name: load() - // Description: populates the DataBuffer with data based on it's id - virtual Status Load(); - // Convenience getter functions for flag checking bool eof() const { return (static_cast(buffer_flags_) & static_cast(kDeBFlagEOF)); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 2dbdb82d26..a2cd6dc07a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -17,7 +17,11 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES take_op.cc shuffle_op.cc zip_op.cc - concat_op.cc + concat_op.cc + cache_base_op.cc + cache_lookup_op.cc + cache_op.cc + cache_merge_op.cc ) if (ENABLE_PYTHON) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.cc new file mode 100644 index 0000000000..42d3f0fee3 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/cache_base_op.h" +#include +#include +#include "dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +// A print method typically used for debugging +void CacheBase::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nCache client:\n" << *cache_client_ << "\n\n"; + } +} +// Overrides base class reset method. When an operator does a reset, it cleans up any state +// info from it's previous execution and then initializes itself so that it can be executed +// again. +Status CacheBase::Reset() { + if (sampler_ != nullptr) { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + } + // Wake up the workers to get them going again in a new epoch + MS_LOG(DEBUG) << Name() << " resetting."; + epoch_sync_.Set(); + return Status::OK(); +} +CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, sampler), + cache_client_(cache_client), + rows_per_buffer_(rows_per_buf), + // We can cause deadlock if this internal Connector size is too small. + keys_miss_(num_workers_, 1, 1024) { + io_block_queues_.Init(num_workers, op_connector_size); +} +// Common function to fetch samples from the sampler and send them using the io_block_queues to +// the parallel workers +Status CacheBase::FetchSamplesToWorkers() { + int64_t buf_cnt = 0; + int64_t wait_cnt = 0; + do { + epoch_sync_.Clear(); + std::vector keys; + int64_t row_cnt = 0; + keys.reserve(rows_per_buffer_); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (!sampler_buffer->eoe()) { + TensorRow sample_row; + RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); itr++) { + keys.push_back(*itr); + ++row_cnt; + if (row_cnt % rows_per_buffer_ == 0) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (!keys.empty()) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); + } + // send the eoe + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + // If repeat but the not last repeat, wait for reset. 
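The wait below pairs with Reset(): after sending one epoch of sample ids the main thread parks on the epoch_sync_ WaitPost, and a repeat pass restarts it. Reduced to its essentials (WaitPost names as used in this file; the timeline is illustrative):

    epoch_sync_.Clear();                    // top of the epoch loop above
    // ... one epoch of IOBlocks plus an eoe is queued to the workers ...
    RETURN_IF_NOT_OK(epoch_sync_.Wait());   // parked until the tree repeats
    // Meanwhile, on repeat, CacheBase::Reset() runs epoch_sync_.Set() to release it.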
+ if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; + RETURN_IF_NOT_OK(epoch_sync_.Wait()); + } else { + // We can break out from the loop. + break; + } + } while (true); + // Flow the eof before exit + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + // Ask all the workers to quit. + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); +} +Status CacheBase::FetchFromCache(int32_t worker_id) { + int64_t buffer_id = worker_id; + std::unique_ptr blk; + do { + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); + if (blk->eof()) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else if (blk->eoe()) { + if (AllowCacheMiss()) { + // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from + // a sampler, send a eoe to physical leaf op as well. + std::vector eoe; + eoe.push_back(eoe_row_id); + RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe)); + } + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(blk->GetKeys(&keys)); + if (keys.empty()) { + // empty key is a quit signal for workers + break; + } + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + std::unique_ptr que = std::make_unique(); + TensorTable ttbl; + RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl)); + auto row_it = ttbl.begin(); + std::vector cache_miss; + cache_miss.reserve(keys.size()); + for (auto row_id : keys) { + auto &row = *row_it; + if (row.empty()) { + if (AllowCacheMiss()) { + cache_miss.push_back(row_id); + } else { + std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; + RETURN_STATUS_UNEXPECTED(errMsg); + } + } + que->push_back(std::move(row)); + ++row_it; + } + db->set_tensor_table(std::move(que)); + if (AllowCacheMiss()) { + // Because of the way connector works, we push unconditionally even cache_miss can be empty. + RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss)); + } + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + } while (true); + return Status::OK(); +} +Status CacheBase::RegisterResources() { + RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + return Status::OK(); +} +CacheBase::~CacheBase() {} +Status CacheBase::UpdateColumnMapFromCache() { + Status rc; + // Get the schema from the server. It may not be there yet. So tolerate the error. 
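The schema is published by the merge stage, not by the leaf: in CacheMergeOp::CacheMissWorkerEntry (later in this patch) worker 0 runs

    RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));

once per pipeline, so a lookup that races ahead of it legitimately sees kFileNotExist and must treat that as "not yet", which is what the code below does.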
+ if (column_name_id_map_.empty()) { + rc = cache_client_->FetchSchema(&column_name_id_map_); + if (rc == Status(StatusCode::kFileNotExist)) { + MS_LOG(DEBUG) << "Schema not in the server yet."; + rc = Status::OK(); + } + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.h new file mode 100644 index 0000000000..a6a98fc4ad --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.h @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ + +#include +#include +#include +#include +#include "dataset/engine/cache/cache_client.h" +#include "dataset/engine/cache/cache_service.h" +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/engine/datasetops/repeat_op.h" +#include "dataset/engine/datasetops/source/io_block.h" +#include "dataset/engine/datasetops/source/sampler/sampler.h" +#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "dataset/util/queue.h" +#include "dataset/util/wait_post.h" +#include "dataset/engine/datasetops/cache_base_op.h" +namespace mindspore { +namespace dataset { +/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. +/// \see CacheOp +/// \see CacheLookupOp +class CacheBase : public ParallelOp { + public: + /// \brief Base class constructor + /// \param num_workers Number of parallel workers + /// \param op_connector_size Connector size + /// \param rows_per_buf Number of rows per buffer + /// \param cache_client CacheClient for communication to the CacheServer + /// \param sampler Sampler which is mandatory + CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler); + /// \brief Destructor + ~CacheBase(); + + constexpr static int eoe_row_id = -1; + + /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state + /// info from it's previous execution and then initializes itself so that it can be executed + /// again. 
+ /// \return Status - The error code return + Status Reset() override; + + /// \brief A print method typically used for debugging + /// \param out The output stream to write output to + /// \param show_all A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + /// \brief << Stream output operator overload + /// \notes This allows you to write the debug print info using stream operators + /// \param out reference to the output stream being overloaded + /// \param mo reference to the CacheOp to display + /// \return the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) { + mo.Print(out, false); + return out; + } + + /// \brief Getter for the cache client + /// \return shared ptr to the cache client + std::shared_ptr cache_client() { return cache_client_; } + /// \brief Setter for the cache client + void SetCacheClient(std::shared_ptr cache_client) { cache_client_ = std::move(cache_client); } + /// \brief Derived class must implement this method if a cache miss is treated as error + virtual bool AllowCacheMiss() = 0; + + protected: + std::shared_ptr cache_client_; + WaitPost epoch_sync_; + int32_t rows_per_buffer_; + Connector> keys_miss_; + + /// \brief Common function to register resources for interrupt + /// \note Derived should override this function for extra resources to be registered + virtual Status RegisterResources(); + /// \brief This function is called by main thread to send samples to the worker thread. + /// \note It is a non-virtual function + /// \return Status object + Status FetchSamplesToWorkers(); + /// \brief This function is called by each worker to fetch rows from the cache server for a given set of + /// sample row id's + /// \return Status object + Status FetchFromCache(int32_t worker_id); + /// \brief Get the column map from cache server + Status UpdateColumnMapFromCache(); + + private: + QueueList> io_block_queues_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.cc new file mode 100644 index 0000000000..196a8790df --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/cache_lookup_op.h" +#include "dataset/engine/opt/pass.h" +#include "dataset/core/config_manager.h" +#include "dataset/core/constants.h" +#include "dataset/core/global_context.h" +#include "dataset/engine/execution_tree.h" +#include "utils/log_adapter.h" +#include "utils/system/crc32c.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. 
+CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_workers_ = cfg->num_parallel_workers(); + rows_per_buffer_ = cfg->rows_per_buffer(); + build_op_connector_size_ = cfg->op_connector_size(); +} + +// Check if the required parameters are set by the builder. +Status CacheLookupOp::Builder::SanityCheck() const { + if (build_cache_client_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient"); + } + // Make sure the cache client has a valid session + if (!build_cache_client_->session_id()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Cache client for CacheLookupOp is missing session id"); + } + return Status::OK(); +} + +// The builder "build" method creates the final object and does some init on it +Status CacheLookupOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, rows_per_buffer_, + build_cache_client_, build_sampler_); + return Status::OK(); +} +Status CacheLookupOp::operator()() { + if (!sampler_) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "CacheLookupOp requires a sampler before it can be executed!"); + } + RETURN_IF_NOT_OK(RegisterResources()); + // Kick off the workers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1))); + // required task group sync after launching workers + TaskManager::FindMe()->Post(); + // We have to wait until the leaf op has handshake with us. + RETURN_IF_NOT_OK(leaf_op_wp_.Wait()); + RETURN_IF_NOT_OK(FetchSamplesToWorkers()); + return Status::OK(); +} +Status CacheLookupOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(FetchFromCache(worker_id)); + return Status::OK(); +} +Status CacheLookupOp::ResetSampler() { return Status::OK(); } +Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) { + // We act like a sampler and as a dataset op. During handshake with leaf op, + // We must wait until the leaf op has indexed everything. + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op)); + // Now we notify the main thread handshake has finished. + leaf_op_wp_.Set(); + return Status::OK(); +} +Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } +void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } +Status CacheLookupOp::GetNextSample(std::unique_ptr *out_buffer) { + std::vector cache_miss; + RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); + // Ignore the case we have no cache miss, we can't return empty samples. 
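As a sampler, CacheLookupOp emits only the row ids that missed the cache. What a consumer of GetNextSample observes, assuming rows 11 and 12 missed (the variable name lookup_op is hypothetical):

    std::unique_ptr<DataBuffer> db;
    RETURN_IF_NOT_OK(lookup_op->GetNextSample(&db));
    // db carries one TensorRow holding a 1-D tensor of row ids {11, 12};
    // at the end of the epoch the eoe_row_id sentinel surfaces as a buffer
    // flagged kDeBFlagEOE instead of a list of ids.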
+ while (cache_miss.empty()) { + RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); + } + // Special code for eoe + if (cache_miss.at(0) == eoe_row_id) { + *out_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + std::shared_ptr sample_ts; + RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size())); + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); + auto idPtr = sample_ts->begin(); + for (auto i = 0; i < cache_miss.size(); ++i) { + *idPtr = cache_miss.at(i); + ++idPtr; + } + TensorRow row; + row.push_back(sample_ts); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + } + return Status::OK(); +} +Status CacheLookupOp::RegisterResources() { + RETURN_IF_NOT_OK(CacheBase::RegisterResources()); + RETURN_IF_NOT_OK(leaf_op_wp_.Register(tree_->AllTasks())); + return Status::OK(); +} +Status CacheLookupOp::ComputeColMap() { + // We don't know the column map at this point unless we contact the cache server + // to fetch the schema but the cache server may not have it at this point either. + // So we will just return OK and let MergeOp (our parent) to handle it. + return Status::OK(); +} + +// Visitor accept method for NodePass +Status CacheLookupOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.h b/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.h new file mode 100644 index 0000000000..526fb7c3a7 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.h @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ + +#include +#include +#include +#include +#include +#include "dataset/engine/datasetops/cache_base_op.h" + +namespace mindspore { +namespace dataset { +/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset. +/// \note For non-mappable dataset, please see CacheOp +/// \see CacheOp +class CacheLookupOp : public CacheBase, public Sampler { + public: + class Builder { + public: + /// \brief Builder constructor. Creates the builder object. + /// \note No default args + Builder(); + + /// Default destructor + ~Builder() = default; + + /// Setter method. + /// \treturn Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + build_num_workers_ = num_workers; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + build_op_connector_size_ = connector_size; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. 
+ Builder &SetClient(std::shared_ptr cache_client) { + build_cache_client_ = cache_client; + return *this; + } + + /// \brief Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + build_sampler_ = std::move(sampler); + return *this; + } + + /// \brief The builder "build" method creates the final object and does some init on it. + /// \param ptr The shared_ptr to the new CacheLookupOp object + /// \return Status + Status Build(std::shared_ptr *ptr); + + private: + int32_t build_num_workers_; + int32_t rows_per_buffer_; + int32_t build_op_connector_size_; + std::shared_ptr build_cache_client_; + std::shared_ptr build_sampler_; + + // Check if the required parameters are set by the builder. + // \return Status The error code return + Status SanityCheck() const; + }; + /// \brief Constructor + /// \note It takes the same argument as the base class. + /// \see CacheBase + CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler) + : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {} + ~CacheLookupOp() = default; + // As a parallel op, we override these two functions + Status operator()() override; + Status WorkerEntry(int32_t worker_id) override; + // As a sampler, we override the following functions + Status ResetSampler() override; + Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; + Status InitSampler() override; + Status GetNextSample(std::unique_ptr *out_buffer) override; + void Print(std::ostream &out, bool show_all) const override; + bool AllowCacheMiss() override { return true; } + std::string Name() const override { return "CacheLookupOp"; } + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + protected: + Status ComputeColMap() override; + + private: + WaitPost leaf_op_wp_; + + Status RegisterResources() override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.cc new file mode 100644 index 0000000000..5d00ec071f --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.cc @@ -0,0 +1,301 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include "dataset/core/config_manager.h" +#include "dataset/core/constants.h" +#include "dataset/core/global_context.h" +#include "dataset/engine/datasetops/cache_merge_op.h" +#include "dataset/engine/opt/pass.h" +#include "dataset/engine/execution_tree.h" +#include "dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +CacheMergeOp::~CacheMergeOp() = default; +void CacheMergeOp::Print(std::ostream &out, bool show_all) + const { // Always show the id and name as first line regardless if this is summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\n\n"; + } +} +CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, + std::shared_ptr cache_client, const std::shared_ptr &sampler) + : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {} +Status CacheMergeOp::operator()() { + // A queue of row id to let cleaner send cache miss rows to the cache server + // We don't want a small queue as this will block the parallel op workers. + // A row id is 8 byte integer. So bigger size doesn't consume a lot of memory. + io_que_ = std::make_unique>(512); + RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1))); + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1))); + // One dedicated thread to move TensorRow from the pool to the cache server + for (auto i = 0; i < num_cleaners_; ++i) { + RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this))); + } + TaskManager::FindMe()->Post(); + return Status::OK(); +} +// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait +// until it shows up in the pool. +Status CacheMergeOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::shared_ptr cache_hit_stream = child_[kCacheHitChildIdx]; + std::unique_ptr db_ptr; + RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); + while (!db_ptr->eof()) { + if (db_ptr->eoe()) { + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + db_ptr.reset(); + RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); + } else { + // See if there is any missing row + auto tbl = std::make_unique(); + while (db_ptr->NumRows() > 0) { + TensorRow row; + RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); + if (row.empty()) { + auto row_id = row.getId(); + TensorRowRequest *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + // Block until the row shows up in the pool. 
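This wait is one half of a rendezvous on the shared TensorRowRequest; the producer half lives in CacheMissWorkerEntry below. Reduced to its essentials (illustrative only, both sides obtain the same rq pointer via GetRq):

    // cache-miss worker (producer):
    rq->WakeUpAny(std::move(miss_row));  // deep-copies once for the cleaner, then V()
    // merge worker (consumer, this thread):
    rq->Wait(&row);                      // P() blocks until a row is queued, then pops it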
+ RETURN_IF_NOT_OK(rq->Wait(&row)); + } + tbl->push_back(std::move(row)); + } + db_ptr->set_tensor_table(std::move(tbl)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); + RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); + } + } + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); + return Status::OK(); +} +Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { + TaskManager::FindMe()->Post(); + // We will simply pop TensorRow from the stream and insert them into the pool and + // wake up any worker that is awaiting on the missing TensorRow. + // If we see an eoe, ignore it. For eof, we exit. + std::shared_ptr cache_missing_stream = child_[kCacheMissChildIdx]; + // Before we start, cache the schema at the server. Pick one of the workers + // do it. The schema should have been done at prepare time. + if (workerId == 0) { + RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map())); + } + std::unique_ptr db_ptr; + RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); + while (!db_ptr->eof()) { + if (db_ptr->eoe()) { + // Ignore it. + MS_LOG(DEBUG) << "Ignore eoe"; + } else { + while (db_ptr->NumRows() > 0) { + TensorRow row; + RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); + row_id_type row_id = row.getId(); + if (row_id < 0) { + std::string errMsg = "Expect positive row id: " + std::to_string(row_id); + RETURN_STATUS_UNEXPECTED(errMsg); + } + TensorRowRequest *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + rq->WakeUpAny(std::move(row)); + // Let the cleaner to flush out this row (async) to the cache server. + RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); + } + } + RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); + } + return Status::OK(); +} +Status CacheMergeOp::Cleaner() { + TaskManager::FindMe()->Post(); + while (true) { + row_id_type row_id; + RETURN_IF_NOT_OK(io_que_->PopFront(&row_id)); + if (row_id < 0) { + break; + } + TensorRowRequest *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + if (rq->GetState() == TensorRowRequest::State::kClean) { + // If already flushed, move on to the next one. + continue; + } + TensorRow row; + RETURN_IF_NOT_OK(rq->Release(&row)); + CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error"); + Status rc = cache_client_->WriteRow(row); + // Bad rc should not bring down the pipeline + if (rc.IsError()) { + MS_LOG(WARNING) << "Cache not successful." << rc.ToString(); + } + rq->SetState(TensorRowRequest::State::kClean); + } + return Status::OK(); +} + +Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) { + RETURN_UNEXPECTED_IF_NULL(out); + std::unique_lock lck(mux_); + auto it = cache_miss_map_.find(row_id); + if (it != cache_miss_map_.end()) { + *out = it->second.GetMutablePointer(); + } else { + // We will create a new one. 
+    auto alloc = Services::GetAllocator();
+    auto r = cache_miss_map_.emplace(row_id, MemGuard>(alloc));
+    if (r.second) {
+      auto &mem = r.first->second;
+      RETURN_IF_NOT_OK(mem.allocate(1, row_id));
+      *out = mem.GetMutablePointer();
+    } else {
+      RETURN_STATUS_UNEXPECTED("Map insert fail.");
+    }
+  }
+  return Status::OK();
+}
+Status CacheMergeOp::PrepareNodePostAction() {
+  // Run any common code from the super class first before adding our own specific logic.
+  CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children");
+  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
+  // Get the computed check sum from all ops in the cache miss branch.
+  uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]);
+  // This is a mappable cache op, so the row id's already come from the dataset and
+  // the cache should not generate them. Construct the cache.
+  const bool generate_ids = false;
+  Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
+  if (rc.get_code() == StatusCode::kDuplicateKey) {
+    // We are told the cache has been created already.
+    MS_LOG(INFO) << "Cache created already";
+    rc = Status::OK();
+  }
+  RETURN_IF_NOT_OK(rc);
+  return Status::OK();
+}
+Status CacheMergeOp::ComputeColMap() {
+  CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty");
+  if (column_name_id_map().empty()) {
+    column_name_id_map_ = child_[kCacheMissChildIdx]->column_name_id_map();
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected");
+  return Status::OK();
+}
+Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) {
+  RETURN_UNEXPECTED_IF_NULL(out);
+  // Block until the missing row is in the pool.
+  RETURN_IF_NOT_OK(use_count_.P());
+  std::unique_lock<std::mutex> lck(dq_mux_);
+  CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error");
+  *out = std::move(row_.front());
+  row_.pop_front();
+  return Status::OK();
+}
+void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) {
+  std::unique_lock<std::mutex> lck(dq_mux_);
+  // Technically, the number of times this row shows up in the cache miss stream equals the
+  // number of P() calls. However, the cleaner wants the row too, so we need an extra copy.
+  if (GetState() == State::kEmpty) {
+    // We will do a deep copy.
+    for (auto &ts : row) {
+      auto out_ts = std::make_shared<Tensor>(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes());
+      cleaner_copy_.push_back(out_ts);
+    }
+    cleaner_copy_.setId(row.getId());
+    // Change the state to dirty.
+    SetState(State::kDirty);
+  }
+  row_.push_back(std::move(row));
+  // Bump up the use count by 1. This wakes up any parallel worker that is waiting
+  // for this row.
+  use_count_.V();
+}
+Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) {
+  RETURN_UNEXPECTED_IF_NULL(out);
+  // We are not holding any mutex here because the cleaner isn't really touching the deque row_.
+  // In case we have multiple cleaners and they all see the copy, only one of them will get it.
+  auto expected = State::kDirty;
+  if (st_.compare_exchange_strong(expected, State::kClean)) {
+    *out = std::move(cleaner_copy_);
+  }
+  return Status::OK();
+}
+// Builder constructor. Creates the builder object.
+CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  build_num_workers_ = cfg->num_parallel_workers();
+  build_op_connector_size_ = cfg->op_connector_size();
+  build_num_cleaners_ = 1;
+}
+
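For reference, a hypothetical construction of the op through this builder (cache_client and sampler are assumed to already exist; unset values fall back to the ConfigManager defaults set in the constructor above):

    std::shared_ptr<CacheMergeOp> merge_op;
    CacheMergeOp::Builder builder;
    RETURN_IF_NOT_OK(builder.SetNumWorkers(4)
                         .SetOpConnectorSize(16)
                         .SetNumCleaner(2)  // cleaner threads flushing misses to the server
                         .SetClient(cache_client)
                         .SetSampler(sampler)
                         .Build(&merge_op));

+// Check if the required parameters are set by the builder.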
+Status CacheMergeOp::Builder::SanityCheck() const { + if (build_cache_client_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheMergeOp requires a CacheClient"); + } + // Make sure the cache client has a valid session + if (!build_cache_client_->session_id()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Cache client for CacheMergeOp is missing session id"); + } + return Status::OK(); +} + +// The builder "build" method creates the final object and does some init on it +Status CacheMergeOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, build_num_cleaners_, + build_cache_client_, build_sampler_); + return Status::OK(); +} + +// Pre-Visitor accept method for NodePass +Status CacheMergeOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->PreRunOnNode(shared_from_base(), modified); +} + +// Visitor accept method for NodePass +Status CacheMergeOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status CacheMergeOp::EoeReceived(int32_t worker_id) { + // If we are in a repeat path, send the eoe up. + // Otherwise ignore it. + if (BitTest(op_ctrl_flags_, kDeOpRepeated)) { + return DatasetOp::EoeReceived(worker_id); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.h new file mode 100644 index 0000000000..60e2ebd0be --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.h @@ -0,0 +1,196 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "dataset/core/tensor_row.h" +#include "dataset/engine/cache/cache_client.h" +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/engine/dataset_iterator.h" +#include "dataset/util/queue.h" +#include "dataset/util/semaphore.h" + +namespace mindspore { +namespace dataset { +/// \brief Provides method to merge two streams (one from CacheLookup and one from cache miss stream) into one single +/// stream +class CacheMergeOp : public ParallelOp { + public: + // Some handshake structures among the main thread, cleaner threads and parallel op threads. + class TensorRowRequest { + public: + enum class State : uint8_t { + kEmpty = 0, // No row in the deque + kDirty = 1, // Cleaner hasn't flushed it to the cache server yet. + kClean = 2 // The row has been flushed already. 
+ }; + explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {} + ~TensorRowRequest() = default; + State GetState() const { return st_; } + void SetState(State newState) { st_ = newState; } + Status Wait(TensorRow *out); + void WakeUpAny(TensorRow &&row); + Status Release(TensorRow *out); + + private: + std::mutex dq_mux_; + std::atomic st_; + Semaphore use_count_; + std::deque row_; + TensorRow cleaner_copy_; + }; + + constexpr static int kCacheHitChildIdx = 0; // Cache hit stream + constexpr static int kCacheMissChildIdx = 1; // Cache miss stream + + /// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of + /// the arguments for constructing it. Use the builder by setting each argument + /// with the provided set methods, and then finally call the build method to execute + /// the actual construction. + class Builder { + public: + /// Builder constructor. Creates the builder object. + /// \note No default args + Builder(); + + /// Default destructor + ~Builder() = default; + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + build_num_workers_ = num_workers; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + build_op_connector_size_ = connector_size; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetClient(std::shared_ptr cache_client) { + build_cache_client_ = cache_client; + return *this; + } + + /// \brief Setter method + /// \param sampler + /// \return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + build_sampler_ = std::move(sampler); + return *this; + } + + /// \brief Setter method + /// \param num_cleaners + /// \return Builder setter method returns reference to the builder. + Builder &SetNumCleaner(int32_t num_cleaners) { + build_num_cleaners_ = num_cleaners; + return *this; + } + + /// The builder "build" method creates the final object and does some init on it. + /// \param ptr The shared_ptr to the new CacheMergeOp object + /// \return Status + Status Build(std::shared_ptr *ptr); + + private: + int32_t build_num_workers_; + int32_t build_op_connector_size_; + int32_t build_num_cleaners_; + std::shared_ptr build_cache_client_; + std::shared_ptr build_sampler_; + + /// Check if the required parameters are set by the builder. 
+    /// \return Status The error code return
+    Status SanityCheck() const;
+  };
+
+  /// \brief Constructor
+  /// \param numWorkers Number of parallel workers as a derived class of ParallelOp
+  /// \param opConnectorSize Connector size as a derived class of ParallelOp
+  /// \param numCleaners Number of cleaners to move cache miss rows into the cache server
+  /// \param cache_client CacheClient to communicate with the Cache server
+  /// \param sampler The sampler, handed to the ParallelOp base class
+  CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
+               std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
+  ~CacheMergeOp();
+  void Print(std::ostream &out, bool show_all) const override;
+  friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
+    mo.Print(out, false);
+    return out;
+  }
+  /// \brief Master thread responsible for spawning all the necessary worker threads for the two streams and
+  /// the threads for the cleaners.
+  /// \return Status object
+  Status operator()() override;
+  /// \brief Entry function for worker threads that fetch rows from the CacheLookupOp
+  /// \param workerId The worker id
+  /// \return Status object
+  Status WorkerEntry(int32_t workerId) override;
+  Status PrepareNodePostAction() override;
+  /// \brief Entry function for worker threads that fetch rows from the cache miss stream
+  /// \param workerId The worker id
+  /// \return Status object
+  Status CacheMissWorkerEntry(int32_t workerId);
+  Status GetRq(row_id_type row_id, TensorRowRequest **);
+
+  /// \brief Base-class override for NodePass pre-visit acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status PreAccept(NodePass *p, bool *modified) override;
+
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
+  /// \brief Base-class override for eoe handling
+  /// \param worker_id The worker id
+  /// \return Status object
+  Status EoeReceived(int32_t worker_id) override;
+
+ protected:
+  Status ComputeColMap() override;
+
+ private:
+  std::mutex mux_;
+  std::map<row_id_type, MemGuard<TensorRowRequest>> cache_miss_map_;
+  std::unique_ptr<Queue<row_id_type>> io_que_;
+  std::shared_ptr<CacheClient> cache_client_;
+  int32_t num_cleaners_;
+
+  /// \brief This is the entry function for the cleaner threads. Each cleaner is responsible for
+  /// moving cache miss TensorRow into the CacheServer.
+  /// \return Status object
+  Status Cleaner();
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
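Reviewer note at this file boundary: the hunks above show GetRq() and the TensorRowRequest life cycle (kEmpty -> kDirty -> kClean), but Cleaner() itself is not in this excerpt. For orientation, a cleaner thread would plausibly look like the sketch below; the WriteRow() call and the PopFront() signature are assumptions about the CacheClient/Queue interfaces, not code from this patch:

    Status CacheMergeOp::Cleaner() {
      TaskManager::FindMe()->Post();
      while (true) {  // a real implementation also needs a shutdown condition, elided here
        row_id_type row_id;
        RETURN_IF_NOT_OK(io_que_->PopFront(&row_id));  // ids queued by the cache miss workers
        TensorRowRequest *rq = nullptr;
        RETURN_IF_NOT_OK(GetRq(row_id, &rq));
        TensorRow row;
        RETURN_IF_NOT_OK(rq->Release(&row));  // flips kDirty -> kClean; only one cleaner wins the row
        if (!row.empty()) {
          RETURN_IF_NOT_OK(cache_client_->WriteRow(row));  // assumed client call to flush the miss
        }
      }
      return Status::OK();
    }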
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc
new file mode 100644
index 0000000000..149f2b0bbb
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc
@@ -0,0 +1,219 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dataset/engine/datasetops/cache_op.h"
+
+#include
+#include
+#include "dataset/core/config_manager.h"
+#include "dataset/core/constants.h"
+#include "dataset/core/global_context.h"
+#include "dataset/engine/datasetops/repeat_op.h"
+#include "dataset/engine/data_buffer.h"
+#include "dataset/engine/execution_tree.h"
+#include "dataset/engine/opt/pass.h"
+#include "dataset/util/task_manager.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor. Creates the builder object.
+CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  build_num_workers_ = cfg->num_parallel_workers();
+  rows_per_buffer_ = cfg->rows_per_buffer();
+  build_op_connector_size_ = cfg->op_connector_size();
+}
+
+// Check if the required parameters are set by the builder.
+Status CacheOp::Builder::SanityCheck() const {
+  if (build_cache_client_ == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient");
+  }
+  // Make sure the cache client has a valid session
+  if (!build_cache_client_->session_id()) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id");
+  }
+  return Status::OK();
+}
+
+// The builder "build" method creates the final object and does some init on it
+Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_, build_cache_client_,
+                                   build_sampler_);
+  RETURN_IF_NOT_OK((*ptr)->InitCache());
+
+  return Status::OK();
+}
+
+// Constructor of CacheOp
+CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
+                 std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
+    : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
+      num_guys_in_(0),
+      phase_(Phase::kBuildPhase) {}
+
+// Destructor
+CacheOp::~CacheOp() = default;
+
+// Private function for cache setup/init work just after construction
+Status CacheOp::InitCache() { return Status::OK(); }
+
+// This class functor will provide the master loop that drives the logic for performing the work
+Status CacheOp::operator()() {
+  if (!sampler_) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                  "CacheOp requires a sampler before it can be executed!");
+  }
+  RETURN_IF_NOT_OK(RegisterResources());
+  // Kick off the workers
+  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
+  // required task group sync after launching workers
+  TaskManager::FindMe()->Post();
+  // Wait for the workers to finish caching the rows.
+  RETURN_IF_NOT_OK(WaitForCachingAllRows());
+  RETURN_IF_NOT_OK(FetchSamplesToWorkers());
+  return Status::OK();
+}
+Status CacheOp::CacheAllRows(int32_t worker_id) {
+  // If the current phase is to fill the cache, do it then.
+  if (phase_ == Phase::kBuildPhase) {
+    // We will take the chance to cache the schema at the server.
+    // Just do it once and pick one worker to do it.
+    if (worker_id == 0) {
+      RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
+    }
+    MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id;
Worker: " << worker_id; + // SAVE mode loop + std::unique_ptr db_ptr; + RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); + while (!db_ptr->eof()) { + if (!db_ptr->eoe()) { + RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr))); + } else { + // In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up + // as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the + // the eoe to indicate the end of the epoch, we should next expect to get the eof. + // Drain this eof so that we don't leave it sitting there on a connector that we'll never fetch + // from again. + RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); + if (!db_ptr->eof()) { + RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child."); + } + } + RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); + } + } + // Let the main guy know we are done. + auto last_guy_in = num_guys_in_.fetch_add(1); + if ((last_guy_in + 1) == num_workers_) { + rows_cache_done_.Set(); + } else { + // Let's do a sync up here. + RETURN_IF_NOT_OK(rows_cache_done_.Wait()); + } + return Status::OK(); +} +Status CacheOp::WaitForCachingAllRows() { + // Wait for the workers to finish caching the rows. + RETURN_IF_NOT_OK(rows_cache_done_.Wait()); + // Move from build phase to fetch phase if we are the one to fill the cache + if (phase_ == Phase::kBuildPhase) { + RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone()); + // Move to the next phase + phase_ = Phase::kFetchPhase; + } + // Get statistics from the server, and if we are not the one to create the cache, + // wait until the state changed from build phase to fetch base. + CacheClient::ServiceStat stat{}; + bool BuildPhaseDone = true; + do { + RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); + BuildPhaseDone = stat.cache_service_state == static_cast(CacheService::State::kFetchPhase); + if (!BuildPhaseDone) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } while (!BuildPhaseDone); + const row_id_type min_key = stat.min_row_id; + const row_id_type max_key = stat.max_row_id; + num_rows_ = max_key - min_key + 1; + MS_LOG(INFO) << "Number of rows cached: " << num_rows_; + MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached; + MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached; + // Now all rows are cached and we have done a sync point check up. Next phase is + // is pick up fetch input from sampler and pass up to the caller. + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} +Status CacheOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(CacheAllRows(worker_id)); + RETURN_IF_NOT_OK(FetchFromCache(worker_id)); + return Status::OK(); +} +Status CacheOp::RegisterResources() { + RETURN_IF_NOT_OK(CacheBase::RegisterResources()); + RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks())); + return Status::OK(); +} + +// Base-class override for setting specific CacheOp configurations. This code will be called +// during the execution tree prepare phase BEFORE traversing down to child operators. +uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; } +// Base-class override for special eoe handler. +// CacheOp must override this because it shall not perform default handling of eoe. Instead +// the CacheOp manages actions related to the end of the epoch. 
+// Base-class override for special eoe handler.
+// CacheOp must override this because it shall not perform default handling of eoe. Instead
+// the CacheOp manages actions related to the end of the epoch.
+Status CacheOp::EoeReceived(int32_t worker_id) {
+  state_ = OpState::kDeOpIdle;
+  return Status::OK();
+}
+// Base-class override for handling cases when an eof is received.
+Status CacheOp::EofReceived(int32_t worker_id) {
+  // EofReceived is overridden because we want to manually handle this eof.
+  // Specifically, the default behaviour is to pack it and flow it up to the next connection.
+  // In this case, we want a no-op behaviour so that we can perform the correct action.
+  return Status::OK();
+}
+
+// Pre-Visitor accept method for NodePass
+Status CacheOp::PreAccept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call the pre-visitation
+  return p->PreRunOnNode(shared_from_base<CacheOp>(), modified);
+}
+
+// Visitor accept method for NodePass
+Status CacheOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<CacheOp>(), modified);
+}
+
+// A public wrapper for creating the cache through the client
+Status CacheOp::CreateCache(uint32_t cache_crc) {
+  // This is a non-mappable cache op, so the row id's need to be generated.
+  // Construct the cache
+  const bool generate_ids = true;
+  Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
+  if (rc.get_code() == StatusCode::kDuplicateKey) {
+    // We are told the cache has been created already. So we skip the build phase.
+    phase_ = Phase::kFetchPhase;
+    rc = Status::OK();
+  }
+  RETURN_IF_NOT_OK(rc);
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
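Worth calling out at this file boundary: CacheOp::CreateCache above and CacheMergeOp::PrepareNodePostAction earlier implement the same create-or-attach idiom, differing only in generate_ids. A condensed restatement with a hypothetical helper name (not code from this patch):

    // cache_crc identifies the cache: it is computed over the child pipeline,
    // so two identical pipelines resolve to the same server-side cache.
    Status CreateOrAttach(CacheClient *cc, uint32_t cache_crc, bool generate_ids, bool *already_built) {
      *already_built = false;
      Status rc = cc->CreateCache(cache_crc, generate_ids);
      if (rc.get_code() == StatusCode::kDuplicateKey) {
        *already_built = true;  // another tree built it first; attach and reuse
        rc = Status::OK();
      }
      return rc;
    }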
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_op.h b/mindspore/ccsrc/dataset/engine/datasetops/cache_op.h
new file mode 100644
index 0000000000..6ec7e95ecf
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/datasetops/cache_op.h
@@ -0,0 +1,168 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
+#define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
+
+#include
+#include
+#include
+#include
+#include "dataset/engine/datasetops/cache_base_op.h"
+
+namespace mindspore {
+namespace dataset {
+/// \brief CacheOp provides a memory/disk cache that acts as a save-point within a non-mappable dataset.
+/// \note For mappable datasets, please see CacheLookupOp.
+/// \see CacheLookupOp
+class CacheOp : public CacheBase, public RandomAccessOp {
+ public:
+  // This CacheOp is for the non-mappable case, where execution is divided into two phases.
+  // In the first phase we cache all the rows from the child (and let the cache server
+  // assign row ids). There is no read access in the first phase. Once the cache is fully built,
+  // we switch to the second phase and fetch requests from the sampler.
+  enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
+
+  /// \brief The nested builder class inside of the CacheOp is used to help manage all of
+  /// the arguments for constructing it. Use the builder by setting each argument
+  /// with the provided set methods, and then finally call the build method to execute
+  /// the actual construction.
+  class Builder {
+   public:
+    /// \brief Builder constructor. Creates the builder object.
+    /// \note No default args
+    Builder();
+
+    /// Default destructor
+    ~Builder() = default;
+
+    /// \brief Setter method.
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      build_num_workers_ = num_workers;
+      return *this;
+    }
+
+    /// \brief Setter method.
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t connector_size) {
+      build_op_connector_size_ = connector_size;
+      return *this;
+    }
+
+    /// \brief Setter method.
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
+      build_cache_client_ = cache_client;
+      return *this;
+    }
+
+    /// \brief Setter method
+    /// \param rows_per_buffer
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    /// \brief Setter method
+    /// \param sampler
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      build_sampler_ = std::move(sampler);
+      return *this;
+    }
+
+    /// \brief The builder "build" method creates the final object and does some init on it.
+    /// \param ptr The shared_ptr to the new CacheOp object
+    /// \return Status
+    Status Build(std::shared_ptr<CacheOp> *ptr);
+
+   private:
+    int32_t build_num_workers_;
+    int32_t rows_per_buffer_;
+    int32_t build_op_connector_size_;
+    std::shared_ptr<CacheClient> build_cache_client_;
+    std::shared_ptr<Sampler> build_sampler_;
+
+    /// \brief Check if the required parameters are set by the builder.
+    /// \return Status The error code return
+    Status SanityCheck() const;
+  };
+
+  /// \brief Constructor of CacheOp
+  /// \note The builder class should be used to call it.
+  /// \param num_workers The number of worker threads.
+  /// \param op_connector_size The size of each queue in the connector.
+  CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
+          std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
+
+  // Destructor
+  ~CacheOp();
+
+  /// \brief Base-class override for setting specific CacheOp configurations. This code will be called
+  /// during the execution tree prepare phase BEFORE traversing down to child operators.
+  uint32_t PrepareFlags() const override;
+  /// \brief Base-class override for special eoe handler.
+  /// CacheOp must override this because it shall not perform default handling of eoe. Instead
+  /// the CacheOp manages actions related to the end of the epoch.
+  /// \return Status - The error code return
+  Status EoeReceived(int32_t worker_id) override;
+  /// \brief Base-class override for NodePass pre-visit acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status PreAccept(NodePass *p, bool *modified) override;
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+  /// \brief Base-class override for handling cases when an eof is received.
+ /// \param worker_id - The worker id + /// \return Status - The error code return + Status EofReceived(int32_t worker_id) override; + Status operator()() override; + Status WorkerEntry(int32_t worker_id) override; + /// \brief Base-class override for handling cases if we allow cache miss + bool AllowCacheMiss() override { return false; } + /// \brief Base-class override for the name of this operator + std::string Name() const override { return "CacheOp"; } + /// \brief A public wrapper for creating the cache through the client + /// \param[in] cache_crc The crc that identifies the cache + /// \see cache_pass.cc + /// \return Status return code + Status CreateCache(uint32_t cache_crc); + + private: + WaitPost rows_cache_done_; + std::atomic num_guys_in_; + Phase phase_; + /// \brief The main thread will wait until all the rows are cached and will start the handshake with the sampler. + /// \return Status object + Status WaitForCachingAllRows(); + /// \brief For non-mappable dataset, there is a build phase where we cache all the rows. + /// \return Status object + Status CacheAllRows(int32_t worker_id); + Status RegisterResources() override; + /// \brief Private function for cache setup/init work just after construction + /// \return Status The error code return + Status InitCache(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc index 4bada31e7e..2cf2e8045f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc @@ -61,46 +61,39 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const { Status ConcatOp::operator()() { // The children_num_ parameter needs to be put here children_num_ = static_cast(child_.size()); - TaskManager::FindMe()->Post(); std::unique_ptr buf; - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - int eof_count = 0; - while (eof_count != children_num_) { + while (eof_count == 0) { for (int i = 0; i < children_num_; i++) { - // 1. Throw the eof buffer when meet it - if (buf->eof() || buf->eoe()) { - RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); + // 1. Read the first buffer + RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); + if (buf->eof()) { + eof_count++; + continue; } // 2. Do verification as for column name, column data type and rank of column data - RETURN_IF_NOT_OK(Verify(i, buf)); - + if (!buf->eoe()) { + RETURN_IF_NOT_OK(Verify(i, buf)); + } // 3. Put the data into output_connector while (!buf->eoe() && !buf->eof()) { RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); } - - // 4. Throw the eoe buffer when meet it - if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) { - RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); - } - // 5. Add eoe buffer after get buffer from all child - if (i == (children_num_ - 1)) { - auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - } - if (buf->eof()) { - eof_count++; - } + } + // 4. Add eoe buffer after get buffer from all child + if (eof_count == 0) { + auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); } } - // 6. 
Add eof buffer in the end manually
+  CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
+                               "Something went wrong, eof count does not match the number of children.");
+  // 5. Add eof buffer in the end manually
   MS_LOG(DEBUG) << "Add the eof buffer manualy in the end.";
   auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF);
   RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
-
   return Status::OK();
 }
@@ -126,12 +119,6 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr &buf) {
   return Status::OK();
 }
 
-Status ConcatOp::PrepareNodePostAction() {
-  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  tree_->AddToEOEOpStack(shared_from_this());
-  return Status::OK();
-}
-
 // We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
 Status ConcatOp::ComputeColMap() {
   if (column_name_id_map_.empty()) {
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h
index 4bcfdbf6c6..e3dd890d07 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h
@@ -75,12 +75,6 @@ class ConcatOp : public PipelineOp {
   // @return Status - The error code return
   Status operator()() override;
 
-  // During tree prepare phase, operators may have specific post-operations to perform depending on
-  // their role.
-  // @notes Derived versions of this function should always call it's superclass version first
-  // before providing their own implementations.
-  Status PrepareNodePostAction() override;
-
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return "ConcatOp"; }
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
index 3e31f6c017..a963033833 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
@@ -153,16 +153,38 @@ Status DatasetOp::Remove() {
     }
   }
 
+  // Finally, clear "this" op's parent and child pointers since we have just
+  // disconnected it from the tree and invalidated its fields.
+  child_.clear();
+  parent_.clear();
+  operator_id_ = kInvalidOperatorId;
+  tree_ = nullptr;
+
   return Status::OK();
 }
 
-// Getter function to get a shared pointer to our childAdds a operator to become our child.
+// Getter function to get a shared pointer to our child
 std::shared_ptr DatasetOp::child(int32_t child_index) const {
+  std::shared_ptr<DatasetOp> return_op = nullptr;
+  if (child_.empty()) {
+    return return_op;
+  }
   MS_ASSERT(child_index < static_cast(child_.size()));
   // Return a shared pointer
   return child_[child_index];
 }
 
+// Getter function to get the parent pointer
+void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
+  if (parent_.empty()) {
+    // common case if this is a root node
+    *parent = nullptr;
+  } else {
+    MS_ASSERT(parent_index < static_cast<int>(parent_.size()));
+    *parent = parent_[parent_index];
+  }
+}
+
 // Creates the connector within this operator
 void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
   MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
@@ -264,19 +286,11 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
 
 // During tree prepare phase, operators may have specific pre-operations to perform depending on
 // their role.
-Status DatasetOp::PrepareNodePreAction() {
-  if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
-  return Status::OK();
-}
+Status DatasetOp::PrepareNodePreAction() { return Status::OK(); }
+
 // During tree prepare phase, operators may have specific post-operations to perform depending on
 // their role.
 Status DatasetOp::PrepareNodePostAction() {
-  // If this op does not have any children and it is in a repeat path of the tree...
-  if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
-    // push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator
-    // above us will consume them.
-    tree_->AddToEOEOpStack(shared_from_this());
-  }
   // Creating Connector object for each op.
   // The consumer of the root node is assumed to be one thread.
   // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
@@ -346,34 +360,13 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_this(), modified);
 }
 
-// A helper function with some common code that leaf nodes can use during
-// prepare phase for checking if they need to assign a sampler to the cache.
-Status DatasetOp::SaveSamplerForCache(bool random_access_op) {
-  // If we are a descendant under a cache op and we have a sampler, then save this sampler
-  // to a stack so that the cache can pick it up during it's processing above us.
-  if (sampler_) {
-    if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
-      // use move semantic to set our sampler_ to null after the move. This is okay because a sampler is
-      // useless to a random data op. It was only being used as a temporary holding until the cache can
-      // be created
-      tree_->AddToSamplerStack(sampler_);
-      MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling.";
-    } else if (!random_access_op) {
-      // A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf.
-      // This is an error because that type of leaf does not use sampling unless there's a cache to hook it into.
-      RETURN_STATUS_UNEXPECTED(
-        "Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree");
-    }
-  }
-
-  if (!random_access_op) {
-    // Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache
-    // we can remove it now from the base.
-    sampler_.reset();
-  }
-
+// Getter for the sampler, and it also removes the sampler from the op
+Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
+  *sampler = sampler_;  // It's okay if sampler_ points to nullptr
+  sampler_.reset();     // clear our member-copy of this pointer. We no longer have this sampler
   return Status::OK();
 }
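The FetchRemoveSampler/SetSampler pair replaces the old stack-based SaveSamplerForCache flow: instead of leaf ops pushing samplers onto a tree-global stack, a transform pass can now move a sampler explicitly. A sketch of that usage with a hypothetical helper (the real logic lives in the cache transform pass, not shown here):

    Status MoveSamplerToCache(const std::shared_ptr<DatasetOp> &leaf, const std::shared_ptr<CacheOp> &cache_op) {
      std::shared_ptr<Sampler> sampler;
      RETURN_IF_NOT_OK(leaf->FetchRemoveSampler(&sampler));  // the leaf gives up ownership
      if (sampler) {
        cache_op->SetSampler(std::move(sampler));  // the cache now drives the sampling
      }
      return Status::OK();
    }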
+
 uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) {
   std::stringstream ss;
   op->tree_->Print(ss, op);
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
index ab5cb90357..f2a8c23282 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
@@ -45,10 +45,10 @@ class DatasetOp : public std::enable_shared_from_this {
 public:
  static constexpr int32_t kInvalidOperatorId = -1;
 
-  // Flags that control operator runtime behaviours
+  // Operator control flags
   enum OpControlFlags {
     kDeOpNone = 0,
-    kDeOpRepeated = 1,  // Operator is a leaf node in a repeat path
+    kDeOpRepeated = 1,        // Operator is a node in a repeat path
     kDeOpLastRepeat = 1 << 1  // We are in the last repeat loop
   };
@@ -71,17 +71,23 @@ class DatasetOp : public std::enable_shared_from_this {
   /// \param child - shared pointer to the child to remove.
   Status RemoveChild(std::shared_ptr child);
 
-  /// \brief Removes this node from the tree and connects it's parent/child together.
+  /// \brief Removes this node from the tree and connects its parent/child together
   /// \return Status eerror code returned
   Status Remove();
 
   /// \brief Getter function to get a shared pointer to our child
-  /// \param child_index - An operator can have n children. Indicates choose which child to return.
+  /// \param[in] child_index An operator can have n children. Indicates which child to return.
+  /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
   std::shared_ptr child(int32_t child_index) const;
 
-  /// \brief Inserts a operator as the parent current op.
-  /// Inserted op will become the sole parent of the current op.
-  /// The existing parent of the current op will be transferred to the inserted op.
+  /// \brief Getter function to get the pointer to our parent
+  /// If there are no parents, it returns null regardless of the given index
+  /// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
+  void Parent(DatasetOp **parent, int32_t parent_index) const;
+
+  // Inserts an operator as the parent of the current op.
+  // The inserted op will become the sole parent of the current op.
+  // The existing parent of the current op will be transferred to the inserted op.
   Status InsertAsParent(std::shared_ptr to_add);
 
   /// \brief Creates the connector within this operator
@@ -161,16 +167,6 @@ class DatasetOp : public std::enable_shared_from_this {
   /// \return Status - The error code return
   virtual Status Reset();
 
-  /// \brief This calls the reset function on this subtree in pre-order
-  /// \return Status - The error code return
-  virtual Status ResetSubtree() {
-    RETURN_IF_NOT_OK(Reset());
-    for (const auto &c : child_) {
-      RETURN_IF_NOT_OK(c->ResetSubtree());
-    }
-    return Status::OK();
-  }
-
   /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on
   /// their role.
/// \notes Derived versions of this function should always call it's superclass version first @@ -296,7 +292,12 @@ class DatasetOp : public std::enable_shared_from_this { /// \return Shared pointer to the sampler (may return nullptr) std::shared_ptr sampler() { return sampler_; } - /// Computes a CRC value for the operator + /// \brief Getter for the sampler, and it also removes the sampler from the op + /// \param[out] sampler A pointer to the output sampler that was removed + /// \return Status error code + Status FetchRemoveSampler(std::shared_ptr *sampler); + + // Computes a CRC value for the operator static uint32_t GenerateCRC(const std::shared_ptr &op); /// \brief A helper templated function for casting "this" pointer to shared_ptr @@ -307,17 +308,24 @@ class DatasetOp : public std::enable_shared_from_this { return std::static_pointer_cast(shared_from_this()); } - protected: - /// Adds a parent operator to this operator - /// \notes External callers do not have access to this function. - /// \param parent - The parent node to add - void AddParent(DatasetOp *parent); + /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. + void SetSampler(std::shared_ptr sampler) { sampler_ = sampler; } + + /// \brief Checks if this is a leaf node (0 children) + /// \return boolean returns true if it's a leaf + bool IsLeaf() { return (child_.empty()); } - /// Removes a parent operator from this operator - /// \notes External callers do not have access to this function. - /// \param parent - The parent node to remove + protected: + /// \brief Removes a parent operator from this operator + /// \notes External callers do not have access to this function + /// \param[in] parent The parent node to remove void RemoveParent(const DatasetOp *parent); + /// \brief Adds a parent operator to this operator + /// \notes External callers do not have access to this function + /// \param[in] parent The parent node to add + void AddParent(DatasetOp *parent); + /// Compute the current op's column map using its child's column map. /// Get called during the tree post-prepare phase in PrepareNodePostAction. /// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1. @@ -325,12 +333,6 @@ class DatasetOp : public std::enable_shared_from_this { /// \return - Status virtual Status ComputeColMap(); - /// A helper function with some common code that leaf nodes can use during - /// pre/pare phase for checking if they need to assign a sampler to the cache. - /// \param random_access_op - indicate if this is a mappable random access leaf or not - /// \return - Status - Status SaveSamplerForCache(bool random_access_op); - std::vector> child_; // Child nodes std::vector parent_; // Parent nodes. No ownership std::shared_ptr sampler_; // Some leaf ops might have a sampler diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc index 4999dddd02..a0de649284 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc @@ -77,26 +77,6 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { } } -// Base-class override for executing specific RepeatOp configurations. This code will be called -// during the execution tree prepare phase when it is visiting this operator. 
-Status RepeatOp::PrepareNodePostAction() { - // Run any common code from super class first before adding our own specific logic - RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); - std::shared_ptr leaf_op = tree_->PopFromEOEOpStack(); - while (leaf_op != nullptr) { - // Track the leaf operators that are under this repeat op. - eoe_ops_.push_back(leaf_op); - leaf_op = tree_->PopFromEOEOpStack(); - } - // Push ourselves to the stack in case one of our ascendants is repeat too. - tree_->AddToEOEOpStack(shared_from_this()); - return Status::OK(); -} - -// Base-class override for setting specific RepeatOp configurations. This code will be called -// during the execution tree prepare phase BEFORE traversing down to child operators. -uint32_t RepeatOp::PrepareFlags() const { return ExecutionTree::kDePrepRepeat; } - // This function returns the buffer that is at the top of our output connector. The caller is // typically our parent node, when the parent is asking us to provide the next buffer of data. // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get @@ -130,7 +110,8 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t wo // Base-class override for handling cases when an eoe is received. Status RepeatOp::EoeReceived(int32_t worker_id) { repeat_count_++; - MS_LOG(DEBUG) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << "."; + MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ + << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated); bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat); // If we've reached the requested repeat count, then flag the eoe nodes @@ -149,8 +130,12 @@ Status RepeatOp::EoeReceived(int32_t worker_id) { return Status::OK(); } - // base-class ResetSubtree - return (DatasetOp::ResetSubtree()); + // Invoke a reset against the eoe nodes only. + for (auto &eoe_op : eoe_ops_) { + RETURN_IF_NOT_OK(eoe_op->Reset()); + } + + return Status::OK(); } // Class functor operator () override. @@ -178,6 +163,18 @@ int32_t RepeatOp::num_consumers() const { } } +// Drive reset actions if needed +Status RepeatOp::Reset() { + // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. + // In that case, we now have to bounce the reset down to our own eoe ops. + MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset."; + for (auto &eoe_op : eoe_ops_) { + RETURN_IF_NOT_OK(eoe_op->Reset()); + } + state_ = OpState::kDeOpRunning; + return Status::OK(); +} + int32_t RepeatOp::num_producers() const { if (child_.empty() || child_[0] == nullptr) { MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. 
Returning 0."; @@ -187,6 +184,12 @@ int32_t RepeatOp::num_producers() const { } } +// Pre-Visitor accept method for NodePass +Status RepeatOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->PreRunOnNode(shared_from_base(), modified); +} + // Visitor accept method for NodePass Status RepeatOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h index bba85c3bb5..7993737aeb 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h @@ -18,6 +18,7 @@ #include #include +#include #include #include "dataset/engine/datasetops/pipeline_op.h" @@ -82,14 +83,6 @@ class RepeatOp : public PipelineOp { // @return Status - The error code return Status operator()() override; - // Base-class override for setting specific RepeatOp configurations. This code will be called - // during the execution tree prepare phase BEFORE traversing down to child operators. - uint32_t PrepareFlags() const override; - - // Base-class override for executing specific RepeatOp configurations. This code will be called - // during the execution tree post-prepare phase when it is visiting this operator. - Status PrepareNodePostAction() override; - // This function returns the buffer that is at the top of our output connector. The caller is // typically our parent node, when the parent is asking us to provide the next buffer of data. // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get @@ -110,6 +103,10 @@ class RepeatOp : public PipelineOp { // @param worker_id - The worker id Status EofReceived(int32_t worker_id) override; + /// \brief reset Op + /// \@return Status - The error code return + Status Reset() override; + // Base-class override. Return the number of workers in the first parent. // @param workerId - The worker id int32_t num_consumers() const override; @@ -118,16 +115,26 @@ class RepeatOp : public PipelineOp { // @param workerId - The worker id int32_t num_producers() const override; - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. 
+ /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit Status Accept(NodePass *p, bool *modified) override; // Op name getter // @return Name of the current Op std::string Name() const override { return "RepeatOp"; } + /// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes + /// \param[in] eoe_op The input leaf/eoe operator to add to the list + void AddToEoeList(std::shared_ptr eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } + private: int32_t max_repeats_; // The number of repeats that the user requested int32_t repeat_count_; // A counter for the current number of executed repeats diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc index c7a4269a39..db357f42ec 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc @@ -22,6 +22,7 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/data_schema.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" #include "dataset/kernels/image/image_utils.h" namespace mindspore { @@ -408,6 +409,12 @@ Status CelebAOp::Reset() { return Status::OK(); } +// Visitor accept method for NodePass +Status CelebAOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + Status CelebAOp::ComputeColMap() { // Set the column name map (base class field) if (column_name_id_map_.empty()) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h index a6fa495a14..fa81babe4c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h @@ -169,6 +169,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp { // @return Status - The error code return Status AddIOBlock(std::unique_ptr *data_buffer); + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + // Op name getter // @return Name of the current Op std::string Name() const { return "CelebAOp"; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc index 8dd615a8c1..d378933c04 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc @@ -26,6 +26,7 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -450,6 +451,12 @@ Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t * } } +// Visitor accept method 
for NodePass +Status CifarOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + Status CifarOp::ComputeColMap() { // set the column name map (base class field) if (column_name_id_map_.empty()) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h index 917b23db94..24324bbebb 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h @@ -155,6 +155,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp { // @return static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count); + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + // Op name getter // @return Name of the current Op std::string Name() const override { return "CifarOp"; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc index 92f6794769..7d14163544 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc @@ -24,6 +24,7 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -624,6 +625,12 @@ Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file, return Status::OK(); } +// Visitor accept method for NodePass +Status CocoOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + Status CocoOp::ComputeColMap() { // Set the column name map (base class field) if (column_name_id_map_.empty()) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h index 3791853798..2a93d26195 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h @@ -200,6 +200,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp { static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, std::vector>> *output_class_indexing); + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc index e65da8707b..4f9a12bd65 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc @@ -26,6 +26,7 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/db_connector.h" #include 
"dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -416,6 +417,12 @@ Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dic return Status::OK(); } +// Visitor accept method for NodePass +Status ManifestOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + Status ManifestOp::ComputeColMap() { // Set the column name map (base class field) if (column_name_id_map_.empty()) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h index c180ea581d..864abf676c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h @@ -172,6 +172,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, std::map *output_class_indexing); + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + // Op name getter // @return Name of the current Op std::string Name() const override { return "ManifestOp"; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc index e98f8ae8c1..8a75cdc579 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc @@ -23,6 +23,7 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -428,6 +429,12 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) { return Status::OK(); } +// Visitor accept method for NodePass +Status MnistOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + Status MnistOp::ComputeColMap() { // set the column name map (base class field) if (column_name_id_map_.empty()) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h index 9bd6276a11..e57dc21d60 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h @@ -152,6 +152,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp { // @return static Status CountTotalRows(const std::string &dir, int64_t *count); + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + // Op name getter // @return Name of the current Op std::string Name() const override { return "MnistOp"; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc index 3a865d8d69..f13de2e5c9 100644 --- 
a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc @@ -22,6 +22,7 @@ #include "dataset/util/random.h" #include "dataset/util/wait_post.h" #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -406,6 +407,12 @@ Status RandomDataOp::Reset() { return Status::OK(); } +// Visitor accept method for NodePass +Status RandomDataOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + Status RandomDataOp::ComputeColMap() { // Extract the column name mapping from the schema and save it in the class. if (column_name_id_map_.empty()) { @@ -415,15 +422,5 @@ Status RandomDataOp::ComputeColMap() { } return Status::OK(); } - -// During tree prepare phase, operators may have specific post-operations to perform depending on -// their role. -Status RandomDataOp::PrepareNodePostAction() { - // Run common code from super class before adding RandomDataOp specific handling - RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); - // Specific handling for this op, we need to do cache op work to assign the sampler to the cache. - RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false)); - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h index b2af27dda3..76d781ee1c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h @@ -203,12 +203,6 @@ class RandomDataOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "RandomDataOp"; } - // During tree prepare phase, operators may have specific post-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - Status PrepareNodePostAction() override; - private: /** * The entry point code for when workers are launched @@ -266,6 +260,12 @@ class RandomDataOp : public ParallelOp { return ++buffer_id_; } + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + // Private function for computing the assignment of the column name map. // @return - Status Status ComputeColMap() override; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index 48f13ff766..6e6d885cb1 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -1019,31 +1019,28 @@ Status TFReaderOp::ComputeColMap() { return Status::OK(); } +// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing +// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so +// that this tf reader will produce the full set of data into the cache. 
+void TFReaderOp::MakeSimpleProducer() { + device_id_ = 0; + num_devices_ = 1; + total_rows_ = 0; + shuffle_files_ = false; + equal_rows_per_shard_ = false; +} + // During tree prepare phase, operators may have specific post-operations to perform depending on // their role. Status TFReaderOp::PrepareNodePostAction() { // Run common code from super class before adding TFReaderOp specific handling RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); - // Specific handling for this op, we need to do cache op work so assign the sampler to the cache - // TF is a special case because it can support file-based sharding/shuffling, or, if there - // is a cache, then it can also do row-based sampler using the sampler on the cache. - // Thus, pass true for random access op flag when saving the sampler. This is a special case, - // since usually a non-mappable dataset would pass false here. - RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true)); - // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into // a simpler producer of all data (no shuffling or sharding or anything) - if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { - device_id_ = 0; - num_devices_ = 1; - total_rows_ = 0; - shuffle_files_ = false; - equal_rows_per_shard_ = false; - sampler_.reset(); // Normally SaveSampler code did this for us, but we passed in true above (See comment) - } else { + if (!BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { // This sanity check had been delayed until now in the prepare loop. - // If we are not in a cache path, then we can validate the the file-based sharding config. + // If we are not in a cache path, then we can validate the file-based sharding config. // If we are in a cache path, there is no file-based sharding so the check is not correct in that // situation. if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast(num_devices_)) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h index 9226c4c6c5..2613bc5e46 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h @@ -246,6 +246,11 @@ class TFReaderOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return dataset_files_list_; } + /// \Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so + /// that this tf reader will produce the full set of data into the cache. + void MakeSimpleProducer(); + // During tree prepare phase, operators may have specific post-operations to perform depending on // their role. 
// @notes Derived versions of this function should always call it's superclass version first diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc index 16a0d64c94..27a343c973 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc @@ -25,6 +25,7 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" using tinyxml2::XMLDocument; using tinyxml2::XMLElement; @@ -449,6 +450,11 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t return Status::OK(); } +// Visitor accept method for NodePass +Status VOCOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} Status VOCOp::ComputeColMap() { // Set the column name map (base class field) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h index 87324b1b7a..ec46a3c7b1 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h @@ -205,6 +205,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp { static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, const py::dict &dict, std::map *output_class_indexing); + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + // Op name getter // @return Name of the current Op std::string Name() const override { return "VOCOp"; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc index 8bc449cdc9..b9fd8a0663 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc @@ -127,12 +127,6 @@ Status TakeOp::FillBuffer(std::unique_ptr *buffer, std::unique_ptrAddToEOEOpStack(shared_from_this()); - return Status::OK(); -} - // Visitor accept method for NodePass Status TakeOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h index 9619a4409d..07626d5f1f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h @@ -78,12 +78,6 @@ class TakeOp : public PipelineOp { // @return Status - The error code return Status operator()() override; - // During tree prepare phase, operators may have specific post-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - Status PrepareNodePostAction() override; - // Base-class override for NodePass visitor acceptor. // @param p - Pointer to the NodePass to be accepted. // @param modified - Whether this node visit modified the pipeline. 
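The Accept() overrides being added to each op, together with the RunOnNode() overloads in pass.cc further down, form a classic double-dispatch visitor: the virtual Accept() call resolves the op's concrete type, and the op then downcasts its own shared pointer (via shared_from_base()) so that C++ overload resolution picks the matching RunOnNode() overload on the pass. The sketch below shows the shape of the pattern in isolation; it is a minimal stand-in, and the names DemoOp, DemoLeafOp and DemoPass are illustrative only, not MindSpore classes.

```cpp
// Minimal sketch of the Accept/RunOnNode double dispatch. DemoOp, DemoLeafOp
// and DemoPass are hypothetical names, not part of the MindSpore code base.
#include <iostream>
#include <memory>

class DemoPass;  // forward declaration, as pass.h does for the op types

class DemoOp : public std::enable_shared_from_this<DemoOp> {
 public:
  virtual ~DemoOp() = default;
  // First dispatch: the virtual call selects the most-derived Accept().
  virtual void Accept(DemoPass *p, bool *modified);
};

class DemoLeafOp : public DemoOp {
 public:
  void Accept(DemoPass *p, bool *modified) override;
};

class DemoPass {
 public:
  virtual ~DemoPass() = default;
  // Catch-all overload: default behaviour shared by every op type.
  virtual void RunOnNode(std::shared_ptr<DemoOp> node, bool *modified) {
    std::cout << "visited a generic op\n";
  }
  // Type-specific overload: unless a concrete pass overrides it, it falls
  // back to the catch-all, mirroring the NodePass fallbacks in pass.cc.
  virtual void RunOnNode(std::shared_ptr<DemoLeafOp> node, bool *modified) {
    RunOnNode(std::static_pointer_cast<DemoOp>(node), modified);
  }
};

void DemoOp::Accept(DemoPass *p, bool *modified) {
  p->RunOnNode(shared_from_this(), modified);
}

void DemoLeafOp::Accept(DemoPass *p, bool *modified) {
  // Second dispatch: downcast the shared pointer (the role shared_from_base()
  // plays in the real code) so overload resolution picks the leaf overload.
  p->RunOnNode(std::static_pointer_cast<DemoLeafOp>(shared_from_this()), modified);
}

int main() {
  DemoPass pass;
  bool modified = false;
  std::shared_ptr<DemoOp> op = std::make_shared<DemoLeafOp>();
  op->Accept(&pass, &modified);  // falls back to the catch-all: "visited a generic op"
  return 0;
}
```

A concrete pass such as the RepeatPass or CachePass below then only overrides the overloads for the op types it cares about and inherits the fallback chain for everything else.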
diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.cc b/mindspore/ccsrc/dataset/engine/execution_tree.cc index 385722e257..18ef8d6bc7 100644 --- a/mindspore/ccsrc/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/dataset/engine/execution_tree.cc @@ -21,6 +21,8 @@ #include "dataset/util/task_manager.h" #include "dataset/engine/opt/pass.h" #include "dataset/engine/opt/pre/removal_pass.h" +#include "dataset/engine/opt/pre/cache_transform_pass.h" +#include "dataset/engine/opt/post/repeat_pass.h" #include "dataset/engine/perf/profiling.h" #include "dataset/engine/perf/monitor.h" @@ -215,18 +217,33 @@ Status ExecutionTree::PrepareTreePreAction() { bool modified = false; std::vector> pre_actions; // Construct pre actions - MS_LOG(INFO) << "Running pre pass"; - pre_actions.push_back(std::make_unique(RemovalPass())); + MS_LOG(INFO) << "Running pre pass loops."; + pre_actions.push_back(std::make_unique()); + pre_actions.push_back(std::make_unique()); // Apply pre action passes for (auto &pass : pre_actions) { RETURN_IF_NOT_OK(pass->Run(this, &modified)); } + MS_LOG(INFO) << "Pre passes complete."; return Status::OK(); } Status ExecutionTree::PrepareTreePostAction() { // The tree is ready to be prepared. tree_state_ = kDeTStatePrepare; + + bool modified = false; + std::vector> post_actions; + // Construct post actions + MS_LOG(INFO) << "Running post pass loops."; + post_actions.push_back(std::make_unique()); + + // Apply post action passes + for (auto &pass : post_actions) { + RETURN_IF_NOT_OK(pass->Run(this, &modified)); + } + MS_LOG(INFO) << "Post passes complete."; + return Status::OK(); } @@ -280,31 +297,5 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr &dataset_op) return Status::OK(); } - -// Adds an operator to the eoe operator stack during prepare phase. -void ExecutionTree::AddToEOEOpStack(std::shared_ptr dataset_op) { eoe_stack_.push(dataset_op); } - -// Pops an operator from the eoe operator stack during prepare phase. -std::shared_ptr ExecutionTree::PopFromEOEOpStack() { - std::shared_ptr top_op = nullptr; - if (!eoe_stack_.empty()) { - top_op = eoe_stack_.top(); - eoe_stack_.pop(); - } - return top_op; -} - -// Adds a sampler to the sampler stack during prepare phase. -void ExecutionTree::AddToSamplerStack(std::shared_ptr sampler) { sampler_stack_.push(sampler); } - -// Pops an operator from the sampler stack during prepare phase. -std::shared_ptr ExecutionTree::PopFromSamplerStack() { - std::shared_ptr top_sampler = nullptr; - if (!sampler_stack_.empty()) { - top_sampler = sampler_stack_.top(); - sampler_stack_.pop(); - } - return top_sampler; -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.h b/mindspore/ccsrc/dataset/engine/execution_tree.h index 5ebfa539ad..92debafa39 100644 --- a/mindspore/ccsrc/dataset/engine/execution_tree.h +++ b/mindspore/ccsrc/dataset/engine/execution_tree.h @@ -200,24 +200,6 @@ class ExecutionTree { // @return Status - The error code return Status PrepareNode(const std::shared_ptr &dataset_op); - /// Adds an operator to the eoe operator stack during prepare phase. - /// \param op - The dataset op to work add to eoe stack - /// \return Status - The error code return - void AddToEOEOpStack(std::shared_ptr dataset_op); - - /// Pops an operator from the eoe operator stack during prepare phase. - /// \return shared_ptr to the popped operator - std::shared_ptr PopFromEOEOpStack(); - - /// Adds a sampler to the sampler stack during prepare phase.
- /// \param samplerop - The dataset op to work add to eoe stack - /// \return Status - The error code return - void AddToSamplerStack(std::shared_ptr sampler); - - /// Pops an operator from the sampler stack during prepare phase. - /// \return shared_ptr to the popped operator - std::shared_ptr PopFromSamplerStack(); - // Return the pointer to the TaskGroup // @return raw pointer to the TaskGroup TaskGroup *AllTasks() const { return tg_.get(); } @@ -248,8 +230,6 @@ class ExecutionTree { TreeState tree_state_; // Tracking the current tree state std::unique_ptr perf_monitor_; // Performance Monitor std::unique_ptr profiling_manager_; // Profiling manager - std::stack> eoe_stack_; // A stack used during prepare phase - std::stack> sampler_stack_; // A stack used during prepare phase }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt index 080d968cfc..e867c25285 100644 --- a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt @@ -2,6 +2,9 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(engine-opt OBJECT pass.cc + post/repeat_pass.cc + pre/cache_pass.cc + pre/cache_transform_pass.cc pre/removal_nodes.cc pre/removal_pass.cc util/printer_pass.cc diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.cc b/mindspore/ccsrc/dataset/engine/opt/pass.cc index aa33e59b8f..17689224ea 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/dataset/engine/opt/pass.cc @@ -16,6 +16,9 @@ #include "dataset/engine/opt/pass.h" #include "dataset/engine/datasetops/batch_op.h" +#include "dataset/engine/datasetops/cache_op.h" +#include "dataset/engine/datasetops/cache_merge_op.h" +#include "dataset/engine/datasetops/cache_lookup_op.h" #include "dataset/engine/datasetops/dataset_op.h" #include "dataset/engine/datasetops/device_queue_op.h" #include "dataset/engine/datasetops/map_op.h" @@ -24,8 +27,15 @@ #include "dataset/engine/datasetops/repeat_op.h" #include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/datasetops/shuffle_op.h" +#include "dataset/engine/datasetops/source/celeba_op.h" +#include "dataset/engine/datasetops/source/cifar_op.h" +#include "dataset/engine/datasetops/source/coco_op.h" +#include "dataset/engine/datasetops/source/manifest_op.h" #include "dataset/engine/datasetops/source/mindrecord_op.h" +#include "dataset/engine/datasetops/source/mnist_op.h" +#include "dataset/engine/datasetops/source/random_data_op.h" #include "dataset/engine/datasetops/source/tf_reader_op.h" +#include "dataset/engine/datasetops/source/voc_op.h" #ifdef ENABLE_PYTHON #include "dataset/engine/datasetops/filter_op.h" #include "dataset/engine/datasetops/source/generator_op.h" @@ -145,6 +155,11 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { } #endif +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); @@ -164,5 +179,70 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) // Fallback to base class visitor by default return 
RunOnNode(std::static_pointer_cast(node), modified); } + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.h b/mindspore/ccsrc/dataset/engine/opt/pass.h index dd9b65b283..8489faa23a 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/dataset/engine/opt/pass.h @@ -47,6 +47,10 @@ class FilterOp; class GeneratorOp; #endif +class RandomDataOp; + +class RepeatOp; + class TakeOp; class ZipOp; @@ -55,6 +59,24 @@ class DeviceQueueOp; class ImageFolderOp; +class CacheOp; + +class MnistOp; + +class ManifestOp; + +class CifarOp; + +class VOCOp; + +class CocoOp; + +class CelebAOp; + +class CacheMergeOp; + +class CacheLookupOp; + // The base class Pass is the basic unit of tree transformation. // The actual implementation of the passes will be derived from here. 
class Pass : public std::enable_shared_from_this { @@ -138,14 +160,42 @@ class NodePass : public Pass { virtual Status RunOnNode(std::shared_ptr node, bool *modified); #endif + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status RunOnNode(std::shared_ptr node, bool *modified); virtual Status RunOnNode(std::shared_ptr node, bool *modified); virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + private: // Helper function to perform DFS visit Status DFSNodeVisit(std::shared_ptr node, bool *modified); diff --git a/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.cc new file mode 100644 index 0000000000..9f7a561aa6 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "dataset/engine/opt/post/repeat_pass.h" +#include "dataset/engine/datasetops/repeat_op.h" +#include "dataset/engine/datasetops/cache_op.h" +#include "dataset/engine/datasetops/cache_lookup_op.h" +#include "dataset/engine/datasetops/cache_merge_op.h" + +namespace mindspore { +namespace dataset { + +RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {} + +// Identifies the subtree below this node as being in a repeated path of the tree. +Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // If we are already repeated, then this is a nested repeat. + if (is_repeated_) { + nested_repeats_++; + } + is_repeated_ = true; + return Status::OK(); +} + +// Identifies the subtree below this node as being in a cache merge path +Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that we're under a merge op + is_merge_ = true; + return Status::OK(); +} + +// Hooks up any identified eoe nodes under this repeat. 
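+// (Ordering note, an observation from how NodePass drives its DFS: PreRunOnNode visits run on the way down the tree while RunOnNode visits run on the way back up, so by the time a repeat op is handled here, every eoe-producing leaf below it has already pushed itself onto the save-area stack.)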
+Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking + std::shared_ptr leaf_op = PopFromEOEOpStack(); + while (leaf_op != nullptr) { + node->AddToEoeList(leaf_op); + leaf_op = PopFromEOEOpStack(); + } + + // If we are a repeat op in the descendant tree of a merge op, then we take the saved lookup op + // and add it to the list of eoe/leaf ops for the repeat, removing it from the save area. + if (is_merge_ && cache_lookup_) { + cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated); + node->AddToEoeList(std::move(cache_lookup_)); + } + + // If we are a nested repeat, then we add ourselves to the repeat stack for the next one above us. + // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. + if (nested_repeats_ > 0) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + AddToEOEOpStack(node); + nested_repeats_--; + } + + // If we are not nested, or we were the top-most repeat, we now clear the flag + if (nested_repeats_ == 0) { + is_repeated_ = false; + } + + return Status::OK(); +} + +// CacheOp removes previous leaf ops and replaces them with itself +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_repeated_) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + // if we are a cache within a repeat path of the tree, then there will be + // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the + // repeat or epoch ctrl operators can work with them for repeat activity during runtime. + // However, since a cache is present: + // - unflag those ops as being repeated ops + // - remove them from the eoe op stack so that the repeat op above in the tree won't know about them + // - add ourselves (the cache op) as an eoe op + // We do this so that those old leaves become one-time use (up to eoe), never repeated. Instead + // the repeating behaviours shall be invoked against the cache op. + std::shared_ptr leaf_op = PopFromEOEOpStack(); + while (leaf_op != nullptr) { + leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat); + leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated); + leaf_op = PopFromEOEOpStack(); + } + AddToEOEOpStack(std::static_pointer_cast(node)); + } + + return Status::OK(); +} + +// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up +// for use with a controlling repeat above it. +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // If we are in a repeat path, then set our repeated flag + if (is_repeated_) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + + // if we are a leaf node then save ourselves in a stack for the repeat operator above us + if (node->IsLeaf()) { + AddToEOEOpStack(node); + } + } + return Status::OK(); +} + +// Turns off the tracking for operations under merge op +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Setting the flag is needed since we didn't call the base class DatasetOp version + if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated); + is_merge_ = false; + cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed + return Status::OK(); +} + +// Saves the lookup op in case it needs to be referenced by a repeat +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + if (!node->IsLeaf()) { + // By definition, the CacheLookup must be a leaf op. Make that clear here.
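+ // (The cache transform pass hooks the lookup in as a direct child of the merge op, so a non-leaf lookup here means an earlier pass has malformed the tree.)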
+ RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); + } + + // If we are in a repeat path already, then there must be a repeat above the merge op + // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. + if (is_repeated_) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + AddToEOEOpStack(node); + } else { + // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we + // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself + // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. + cache_lookup_ = std::static_pointer_cast(node); + } + return Status::OK(); +} + +// Adds an operator to the eoe operator stack save area +void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { eoe_stack_.push(dataset_op); } + +// Pops an operator from the eoe operator stack save area +std::shared_ptr RepeatPass::PopFromEOEOpStack() { + std::shared_ptr top_op = nullptr; + if (!eoe_stack_.empty()) { + top_op = eoe_stack_.top(); + eoe_stack_.pop(); + } + return top_op; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.h new file mode 100644 index 0000000000..3f5f347a30 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ +#define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ + +#include +#include +#include +#include "dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +/// \class RepeatPass repeat_pass.h +/// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references +/// to the eoe-producing (typically leaf) nodes underneath it. +class RepeatPass : public NodePass { + public: + /// \brief Constructor + RepeatPass(); + + /// \brief Identifies the subtree below this node as being in a repeated path of the tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the subtree below this node as being in a cache merge path + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Hooks up any identified eoe nodes under this repeat. 
+ /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief CacheOp removes previous leaf ops and replaces them with itself + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Turns off the tracking for operations under merge op + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Saves the lookup op in case it needs to be referenced by a repeat + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up + /// for use with a controlling repeat above it. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + private: + /// \brief Adds an operator to the eoe operator stack save area + /// \param op - The dataset op to add to the eoe stack + void AddToEOEOpStack(std::shared_ptr dataset_op); + + /// \brief Pops an operator from the eoe operator stack save area + /// \return shared_ptr to the popped operator + std::shared_ptr PopFromEOEOpStack(); + + bool is_repeated_; // T/F if we are processing under a repeat + bool is_merge_; // T/F if we are processing under a cache merge op + int32_t nested_repeats_; // A counter for nested repeats + std::stack> eoe_stack_; // A save area for leaf/eoe ops + std::shared_ptr cache_lookup_; // A save area for a cache lookup op +}; } // namespace dataset } // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.cc b/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.cc new file mode 100644 index 0000000000..ae0f4d3a04 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.cc @@ -0,0 +1,181 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include +#include "dataset/engine/opt/pre/cache_pass.h" +#include "dataset/engine/opt/pre/cache_transform_pass.h" +#include "dataset/engine/datasetops/cache_op.h" +#include "dataset/engine/datasetops/source/celeba_op.h" +#include "dataset/engine/datasetops/source/generator_op.h" +#include "dataset/engine/datasetops/source/manifest_op.h" +#include "dataset/engine/datasetops/source/mnist_op.h" +#include "dataset/engine/datasetops/source/voc_op.h" +#include "dataset/engine/datasetops/source/cifar_op.h" +#include "dataset/engine/datasetops/source/coco_op.h" +#include "dataset/engine/datasetops/source/image_folder_op.h" +#include "dataset/engine/datasetops/source/random_data_op.h" +#include "dataset/engine/datasetops/source/tf_reader_op.h" +#include "dataset/engine/datasetops/source/mindrecord_op.h" + +namespace mindspore { namespace dataset { + +// Constructor +CachePass::CachePass(CacheTransformPass *transform_pass) + : transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {} + +// Identifies the subtree below this node as a cached descendant tree. +Status CachePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("Nested cache operations are not supported!"); + } + is_caching_ = true; + return Status::OK(); +} + +// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache +// transformation +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + is_caching_ = false; // We are no longer in a cache subtree; clear the flag. + if (leaf_op_) { + MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; + // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, + // using base class pointers. + transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node); + } else { + // If there was no leaf_op set, then this is a non-mappable scenario. + + if (sampler_) { + // Grab the sampler that was saved from the leaf and plug it into the cache op + node->SetSampler(std::move(sampler_)); + MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; + } else { + // We're a cache op but no sampler was saved from the leaf, so create a default sampler + int64_t num_samples = 0; + int64_t start_index = 0; + sampler_ = std::make_shared(num_samples, start_index); + node->SetSampler(std::move(sampler_)); + MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; + } + + // Get the computed checksum from all ops in our cache path below us and ask the cache op to create its cache + uint32_t cache_crc = DatasetOp::GenerateCRC(node); + RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); + } + + return Status::OK(); +} + +// Common code for mappable leaf setup. +Status CachePass::MappableCacheLeafSetup(std::shared_ptr leaf_op) { + // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. + if (is_caching_ && leaf_op_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); + } + + // If we are a leaf in the caching path, then save this leaf.
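+ // (Only a single mappable leaf per cache is supported; the leaf saved here is later paired with its cache op in the CacheOp visitor above and handed to the owning transform pass for the actual tree rewrite.)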
+ if (is_caching_) { + MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; + leaf_op_ = std::move(leaf_op); + } + return Status::OK(); +} + +// Common code for non-mappable leaf setup. +Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr leaf_op) { + // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. + if (is_caching_ && leaf_op_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); + } + + // A sampler on a non-mappable dataset only works if there is a downstream cache. Remove it from the leaf + // and save it for use by the cache op in the ascendant tree. + if (is_caching_) { + RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); + MS_LOG(DEBUG) << "Cache transform pass: Non-mappable leaf in a cache descendant tree detected"; + } else { + // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can + // remove it here. The leaf itself will provide its own methods of fetching the data (not sampler-based) + std::shared_ptr sampler_from_leaf; + RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); + } + return Status::OK(); +} + +// Perform leaf node cache transform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_caching_) { + // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic + // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. + node->MakeSimpleProducer(); + } + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return
MappableCacheLeafSetup(std::static_pointer_cast(node)); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.h b/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.h new file mode 100644 index 0000000000..c842e54bbf --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.h @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ + +#include +#include +#include +#include "dataset/engine/opt/pass.h" + +namespace mindspore { namespace dataset { + +class CacheTransformPass; + +/// \class CachePass cache_pass.h +/// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache +/// transformation. It works in conjunction with the CacheTransformPass. +class CachePass : public NodePass { + public: + /// \brief Constructor + /// \param[in] transform_pass Raw pointer back to controlling tree pass + explicit CachePass(CacheTransformPass *transform_pass); + + /// \brief Identifies the subtree below this node as a cached descendant tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache + /// transformation + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform
identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + private: + /// \brief Common code for mappable leaf setup. + /// \param[in] leaf_op The leaf node performing setup work. + /// \return Status The error code return + Status MappableCacheLeafSetup(std::shared_ptr leaf_op); + + /// \brief Common code for non-mappable leaf setup. + /// \param[in] leaf_op The leaf node performing setup work. + /// \return Status The error code return + Status NonMappableCacheLeafSetup(std::shared_ptr leaf_op); + + bool is_caching_; + std::shared_ptr leaf_op_; + std::shared_ptr sampler_; + CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass +}; + +} // namespace dataset } // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.cc new file mode 100644 index 0000000000..df4933fa1c --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "dataset/engine/opt/pre/cache_pass.h" +#include "dataset/engine/opt/pre/cache_transform_pass.h" +#include "dataset/engine/execution_tree.h" +#include "dataset/engine/cache/cache_client.h" +#include "dataset/engine/datasetops/cache_lookup_op.h" +#include "dataset/engine/datasetops/cache_merge_op.h" +#include "dataset/engine/datasetops/cache_op.h" + +namespace mindspore { namespace dataset { + +// constructor +CacheTransformPass::CacheTransformPass() {} + +// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations +Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) { + MS_LOG(INFO) << "Pre pass: Cache transform pass started."; + // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will + // use to execute a transform. + std::unique_ptr cache_pass = std::make_unique(this); + RETURN_IF_NOT_OK(cache_pass->Run(tree, modified)); + + // Then, execute the transform for each pair + for (auto cache_pair : cache_pairs_) { + MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; + RETURN_IF_NOT_OK(ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client())); + } + MS_LOG(INFO) << "Pre pass: Cache transform pass complete."; + return Status::OK(); +} + +// Helper function to execute the cache transformation. +Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op, + std::shared_ptr cache_op, + std::shared_ptr cache_client) { + // Get local pointers to the child/parent of the cache op. It's possible that the parent is null if the cache was + // the root node. It is also possible that cache_child == leaf_op + std::shared_ptr cache_child = cache_op->child(0); + DatasetOp *cache_parent = nullptr; + cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent + + // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later. + std::shared_ptr leaf_sampler = leaf_op->sampler(); + + // Construct the merge op with defaults + std::shared_ptr merge_op; + CacheMergeOp::Builder merge_builder; + RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op)); + RETURN_IF_NOT_OK(tree->AssociateNode(merge_op)); + + // Construct the cache lookup op with defaults + std::shared_ptr cache_lookup_op; + CacheLookupOp::Builder lookup_builder; + RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op)); + RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op)); + + // Overwrite the old sampler in this leaf op to become the lookup op + leaf_op->SetSampler(cache_lookup_op); + + // If the cache had a parent, then go into that parent to remove the cache from its child list and then + // replace it with the merge op. + if (cache_parent != nullptr) { + RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op)); + RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op)); + } else { + // If we didn't have a parent, then the merge op is the root node + RETURN_IF_NOT_OK(tree->AssignRoot(merge_op)); + } + + // Set the cache op to no longer be a parent over its child. This will fully disconnect the old cache op. + // We maintain a local pointer to the old child though.
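+ // (After the reconnection steps below, the subtree reads cache_parent -> merge_op -> {cache_lookup_op, cache_child}, matching the 'Transformed' diagram in cache_transform_pass.h.)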
+ RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child)); + + // Connect the merge op + RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op))); + RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child))); + + // At this point, the cache op has already had its children and parents taken away. Calling remove + // on it at this point will not do any node hookups, and instead set internal fields to invalid. + RETURN_IF_NOT_OK(cache_op->Remove()); + + return Status::OK(); +} + +// Assigns the leaf and cache operators that are involved in a cache transformation +void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr leaf_op, + std::shared_ptr cache_op) { + cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.h new file mode 100644 index 0000000000..dc31d76d80 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ + +#include +#include +#include +#include "dataset/engine/opt/pass.h" + +namespace mindspore { namespace dataset { + +class DatasetOp; + +class CacheClient; + +/// \class CacheTransformPass cache_transform_pass.h +/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching +/// operations +class CacheTransformPass : public TreePass { + public: + /// \brief Constructor + CacheTransformPass(); + + /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations + /// \param[inout] tree The tree to operate on. + /// \param[inout] modified Indicator if the tree was modified. + /// \return Status The error code return + Status RunOnTree(ExecutionTree *tree, bool *modified) override; + + /// \brief Assigns the leaf and cache operators that are involved in a cache transformation + /// \param[in] leaf_op The leaf operator involved in the cache transform + /// \param[in] cache_op The cache operator involved in the cache transform + void AddMappableCacheOperators(std::shared_ptr leaf_op, std::shared_ptr cache_op); + + private: + /// \brief Helper function to execute the cache transformation.
+ /// + /// Input: + /// Sampler + /// | + /// LeafOp --> OtherOps --> CacheOp + /// + /// Transformed: + /// Sampler --> CacheLookupOp ----------------> + /// | | + /// | MergeOp + /// | | + /// LeafOp --> OtherOps --> + /// + /// \param[in] leaf_op The leaf node in the transform + /// \param[in] cache_op The cache op in the transform (will get removed) + /// \param[in] cache_client The cache client + /// \return Status The error code return + Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op, + std::shared_ptr cache_op, std::shared_ptr cache_client); + + // The two operators that work together to establish the cache transform + std::vector, std::shared_ptr>> cache_pairs_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc index 831a2a76ba..e361015e48 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc +++ b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc @@ -24,12 +24,28 @@ namespace dataset { RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} +// Identifies the subtree below this node as a cached descendant tree. +Status RemovalNodes::PreRunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; + is_caching_ = true; + return Status::OK(); +} + +// Resets the tracking of the cache within the tree +Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; + is_caching_ = false; + return Status::OK(); +} + // Perform ShuffleOp removal check. Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { *modified = false; // If we are in a cache descendant tree, then this shuffle op needs to be removed if (is_caching_) { - MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; + MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; if (removal_pass_) { removal_pass_->AddToRemovalList(std::static_pointer_cast(node)); } else { diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h index 11ef37d80c..7e4a89e3da 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h +++ b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h @@ -34,6 +34,18 @@ class RemovalNodes : public NodePass { /// \param[in] removal_pass Raw pointer back to controlling tree pass explicit RemovalNodes(RemovalPass *removal_pass); + /// \brief Identifies the subtree below this node as a cached descendant tree. 
+ /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Resets the tracking of the cache within the tree + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Perform ShuffleOp removal check /// \param[in] node The node being visited /// \param[inout] modified Indicator if the node was changed at all diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc index 31ec31234f..db5e37a085 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc +++ b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc @@ -28,6 +28,7 @@ RemovalPass::RemovalPass() {} // Runs a removal_nodes pass first to find out which nodes to remove, then removes them. Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { + MS_LOG(INFO) << "Pre pass: removal pass started."; // Create the removal node pass which can identify which nodes need to be removed. std::unique_ptr removal_nodes = std::make_unique(this); RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); @@ -36,6 +37,7 @@ Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { for (auto node : removal_nodes_) { node->Remove(); } + MS_LOG(INFO) << "Pre pass: removal pass complete."; return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/util/allocator.h b/mindspore/ccsrc/dataset/util/allocator.h index 50a9cadbe3..1998716438 100644 --- a/mindspore/ccsrc/dataset/util/allocator.h +++ b/mindspore/ccsrc/dataset/util/allocator.h @@ -87,8 +87,9 @@ class Allocator { std::shared_ptr pool_; }; /// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will -/// be released when the object goes out of scope \tparam T The type of object to be allocated \tparam C Allocator. -/// Default to std::allocator +/// be released when the object goes out of scope +/// \tparam T The type of object to be allocated +/// \tparam C Allocator. Default to std::allocator template > class MemGuard { public: @@ -168,7 +169,7 @@ class MemGuard { private: allocator alloc_; - std::unique_ptr> ptr_; + std::unique_ptr ptr_; size_t n_; }; } // namespace dataset diff --git a/mindspore/ccsrc/dataset/util/cache_pool.cc b/mindspore/ccsrc/dataset/util/cache_pool.cc index 92504cd063..7d7a2a4a94 100644 --- a/mindspore/ccsrc/dataset/util/cache_pool.cc +++ b/mindspore/ccsrc/dataset/util/cache_pool.cc @@ -98,11 +98,6 @@ Status CachePool::Insert(const std::vector &buf, CachePool::key_t } catch (std::bad_alloc &e) { if (sm_ != nullptr) { RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); - // We have an assumption 0 is not a valid key from the design of AutoIndexObj. - // Make sure it is not 0. 
- if (bl.storage_key == 0) { - RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected"); - } } else { return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); } diff --git a/mindspore/ccsrc/dataset/util/services.cc b/mindspore/ccsrc/dataset/util/services.cc index 6516deea41..755d217311 100644 --- a/mindspore/ccsrc/dataset/util/services.cc +++ b/mindspore/ccsrc/dataset/util/services.cc @@ -22,11 +22,11 @@ #include #endif #include +#include "dataset/engine/cache/cache_server.h" #include "dataset/util/circular_pool.h" #include "dataset/util/random.h" #include "dataset/util/task_manager.h" -#define SLOT_TASK_MGR 0 namespace mindspore { namespace dataset { std::unique_ptr Services::instance_ = nullptr; @@ -61,15 +61,25 @@ std::string Services::GetUniqueID() { TaskManager &Services::getTaskMgrInstance() { Services &sm = GetInstance(); - return *(static_cast(sm.sa_[SLOT_TASK_MGR])); + return *(static_cast(sm.sa_[kSlotTaskMgr_])); +} + +CacheServer &Services::getCacheServer() { + Services &sm = GetInstance(); + return *(static_cast(sm.sa_[kSlotCacheMgr_])); } Status Services::CreateAllInstances() { // In order, TaskMgr, BufferMgr Status rc; - sa_[SLOT_TASK_MGR] = new (&rc, pool_) TaskManager(); + sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager(); RETURN_IF_NOT_OK(rc); - rc = sa_[SLOT_TASK_MGR]->ServiceStart(); + rc = sa_[kSlotTaskMgr_]->ServiceStart(); + RETURN_IF_NOT_OK(rc); + // TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers + sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3); + RETURN_IF_NOT_OK(rc); + rc = sa_[kSlotCacheMgr_]->ServiceStart(); return rc; } @@ -83,8 +93,14 @@ Services::Services() : pool_(nullptr), sa_{nullptr} { Services::~Services() noexcept { try { // In reverse order - TaskManager *tm = static_cast(sa_[SLOT_TASK_MGR]); - if (tm) { + CacheServer *cs = static_cast(sa_[kSlotCacheMgr_]); + if (cs != nullptr) { + (void)cs->ServiceStop(); + cs->~CacheServer(); + pool_->Deallocate(cs); + } + TaskManager *tm = static_cast(sa_[kSlotTaskMgr_]); + if (tm != nullptr) { (void)tm->ServiceStop(); tm->~TaskManager(); pool_->Deallocate(tm); diff --git a/mindspore/ccsrc/dataset/util/services.h b/mindspore/ccsrc/dataset/util/services.h index e19f44dccc..e82b3e47f1 100644 --- a/mindspore/ccsrc/dataset/util/services.h +++ b/mindspore/ccsrc/dataset/util/services.h @@ -27,7 +27,7 @@ namespace mindspore { namespace dataset { class TaskManager; - +class CacheServer; class Services { public: static Status CreateInstance() { @@ -61,6 +61,8 @@ class Services { static TaskManager &getTaskMgrInstance(); + static CacheServer &getCacheServer(); + std::shared_ptr GetServiceMemPool() { return pool_; } #if !defined(_WIN32) && !defined(_WIN64) @@ -87,7 +89,9 @@ class Services { // We use pointers here instead of unique_ptr because we // want to have ultimate control on the order of // construction and destruction. 
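+ // Each service occupies a fixed slot in sa_[]: CreateAllInstances() starts the slots in ascending order and the destructor stops them in reverse.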
- static constexpr int kNumServices_ = 1; + static constexpr int kSlotTaskMgr_ = 0; + static constexpr int kSlotCacheMgr_ = 1; + static constexpr int kNumServices_ = 2; Service *sa_[kNumServices_]; Services(); diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 971915f27e..b2d26b41ee 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -24,6 +24,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ WeightedRandomSampler, Sampler +from .engine.cache_client import DatasetCache from .engine.serializer_deserializer import serialize, deserialize, show from .engine.graphdata import GraphData diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py new file mode 100644 index 0000000000..800c0dab1d --- /dev/null +++ b/mindspore/dataset/engine/cache_client.py @@ -0,0 +1,49 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cache client. +""" + +import copy +from mindspore._c_dataengine import CacheClient + +class DatasetCache: + """ + A client to interface with the tensor caching service. + """ + + def __init__(self, session_id=None, size=None, spilling=False): + if session_id is None: + raise RuntimeError("Session generation is not implemented yet. 
session id required") + self.size = size if size is not None else 0 + if size < 0: + raise ValueError("cache size should be 0 or positive integer value but got: size={}".format(size)) + if not isinstance(spilling, bool): + raise ValueError( + "spilling argument for cache should be a boolean value but got: spilling={}".format(spilling)) + self.session_id = session_id + self.spilling = spilling + self.cache_client = CacheClient(session_id, size, spilling) + + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + cls = self.__class__ + new_cache = cls.__new__(cls) + memodict[id(self)] = new_cache + new_cache.session_id = copy.deepcopy(self.session_id, memodict) + new_cache.spilling = copy.deepcopy(self.spilling, memodict) + new_cache.size = copy.deepcopy(self.size, memodict) + new_cache.cache_client = self.cache_client + return new_cache diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 6ad4abe052..c1ef6a9922 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ - check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32 + check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -386,7 +386,7 @@ class Dataset: @check_map def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None, python_multiprocessing=False): + num_parallel_workers=None, python_multiprocessing=False, cache=None): """ Apply each operation in operations to this dataset. @@ -427,6 +427,7 @@ class Dataset: parallel (default=None, the value from the config will be used). python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This option could be beneficial if the python operation is computational heavy (default=False). + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) Returns: MapDataset, dataset after mapping operation. @@ -541,7 +542,7 @@ class Dataset: >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order) """ return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers, - python_multiprocessing) + python_multiprocessing, cache) @check_filter def filter(self, predicate, input_columns=None, num_parallel_workers=1): @@ -1868,13 +1869,14 @@ class MapDataset(DatasetOp): in parallel (default=None). python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This option could be beneficial if the python operation is computational heavy (default=False). + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) Raises: ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified. 
""" def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None, python_multiprocessing=False): + num_parallel_workers=None, python_multiprocessing=False, cache=None): super().__init__(num_parallel_workers) self.children.append(input_dataset) if input_columns is not None and not isinstance(input_columns, list): @@ -1886,6 +1888,7 @@ class MapDataset(DatasetOp): if output_columns is not None and not isinstance(output_columns, list): output_columns = [output_columns] self.output_columns = output_columns + self.cache = cache self.columns_order = columns_order if self.input_columns and self.output_columns \ @@ -1904,6 +1907,7 @@ class MapDataset(DatasetOp): args["operations"] = self.operations args["output_columns"] = self.output_columns args["columns_order"] = self.columns_order + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -1929,6 +1933,7 @@ class MapDataset(DatasetOp): new_op.parent = copy.deepcopy(self.parent, memodict) new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) + new_op.cache = copy.deepcopy(self.cache, memodict) new_op.operations = self.operations return new_op @@ -2346,7 +2351,7 @@ class RangeDataset(MappableDataset): return False -def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): +def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False): """ Create sampler based on user input. @@ -2356,7 +2361,11 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): shuffle (bool): Shuffle. num_shards (int): Number of shard for sharding. shard_id (int): Shard ID. + non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False). """ + if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]): + return None + if input_sampler is not None: # If the user provided a sampler, then it doesn't matter what the other args are because # we are being asked specifically to use the given sampler. @@ -2369,7 +2378,7 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, samplers.WeightedRandomSampler, samplers.Sampler)) and - (num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)): + (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))): raise ValueError( 'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},' ' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle)) @@ -2458,6 +2467,7 @@ class ImageFolderDatasetV2(MappableDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) Raises: RuntimeError: If sampler and shuffle are specified at the same time. 
@@ -2482,7 +2492,7 @@ class ImageFolderDatasetV2(MappableDataset): @check_imagefolderdatasetv2 def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, extensions=None, class_indexing=None, - decode=False, num_shards=None, shard_id=None): + decode=False, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir @@ -2494,6 +2504,7 @@ class ImageFolderDatasetV2(MappableDataset): self.decode = decode self.num_shards = num_shards self.shard_id = shard_id + self.cache = cache def get_args(self): args = super().get_args() @@ -2506,6 +2517,7 @@ class ImageFolderDatasetV2(MappableDataset): args["decode"] = self.decode args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -3251,6 +3263,7 @@ class TFRecordDataset(SourceDataset): argument should be specified only when num_shards is also specified. shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number of rows of each shard may be not equal. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) Examples: >>> import mindspore.dataset as ds >>> import mindspore.common.dtype as mstype @@ -3268,7 +3281,7 @@ class TFRecordDataset(SourceDataset): @check_tfrecorddataset def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, - shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False): + shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None): super().__init__(num_parallel_workers) self.dataset_files = self._find_files(dataset_files) self.dataset_files.sort() @@ -3280,6 +3293,7 @@ class TFRecordDataset(SourceDataset): self.schema = schema self.columns_list = columns_list self.num_samples = num_samples + self.cache = cache if schema_obj is not None and num_samples is None: self.num_samples = schema_obj.num_rows @@ -3295,6 +3309,14 @@ class TFRecordDataset(SourceDataset): else: self.shuffle_level = shuffle self.shuffle_files = True + + # The TF record dataset does not directly support a sampler. It has provided sampling arguments + # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in + # the pipeline contains a cache. If there is no cache above it, then this sampler is not used. + sampler_shuffle = self.shuffle_files + sampler = None + self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, + non_mappable=True) self.shard_equal_rows = shard_equal_rows def get_args(self): @@ -3318,6 +3340,8 @@ class TFRecordDataset(SourceDataset): args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id args["shard_equal_rows"] = self.shard_equal_rows + args["cache"] = self.cache.cache_client if self.cache is not None else None + args["sampler"] = self.sampler return args def get_dataset_size(self, estimate=False): @@ -3803,43 +3827,61 @@ class RandomDataset(SourceDataset): A source dataset that generates random data. Args: - num_samples (int): number of samples to generate. + total_rows (int): number of rows for the dataset to generate (default=None, number of rows is random) schema (str or Schema, optional): Path to the json schema file or schema object (default=None). 
If the schema is not provided, the random dataset generates a random schema. columns_list (list[str], optional): List of columns to be read (default=None, read all columns) + num_samples (int): number of samples to draw from the total. (default=None, which means all rows) num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) + shuffle (bool, optional): Whether or not to perform shuffle on the dataset + (default=None, expected order behavior shown in the table). + num_shards (int, optional): Number of shards that the dataset should be divided + into (default=None). + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument should be specified only when num_shards is also specified. """ - def __init__(self, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None): + @check_random_dataset + def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, + cache=None, shuffle=None, num_shards=None, shard_id=None): super().__init__(num_parallel_workers) schema_obj = None if (schema is not None) and (not isinstance(schema, Schema)): schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it self.schema = schema self.columns_list = columns_list - if schema_obj is not None and num_samples is None: - self.num_samples = schema_obj.num_rows - elif num_samples is None: - self.num_samples = 0 + sampler = None + self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True) + self.num_samples = num_samples + self.cache = cache + if schema_obj is not None and total_rows is None: + self.total_rows = schema_obj.num_rows + elif total_rows is None: + self.total_rows = 0 else: - self.num_samples = num_samples + self.total_rows = total_rows + self.num_shards = num_shards + self.shard_id = shard_id + self.shuffle_level = shuffle def get_args(self): args = super().get_args() if self.schema is not None: if isinstance(self.schema, Schema): self.schema.datasetType = 'Random' - if self.num_samples is not None: - self.schema.num_rows = self.num_samples + if self.total_rows is not None: + self.schema.num_rows = self.total_rows args["schema_json_string"] = self.schema.to_json() else: args["schema_file_path"] = self.schema args["schema"] = self.schema - if self.columns_list is not None: - args["columns_list"] = self.columns_list - if self.num_samples is not None: - args["num_samples"] = self.num_samples + args["columns_list"] = self.columns_list + args["num_samples"] = self.num_samples + args["total_rows"] = self.total_rows + args["cache"] = self.cache.cache_client if self.cache is not None else None + args["sampler"] = self.sampler return args def get_dataset_size(self): @@ -3849,18 +3891,29 @@ class RandomDataset(SourceDataset): Return: Number, number of batches. 
""" + + num_rows = CifarOp.get_num_rows(self.dataset_dir, True) + + rows_per_shard = get_num_rows(num_rows, self.num_shards) rows_from_sampler = self._get_sampler_dataset_size() if rows_from_sampler is None: - return self.num_samples + return rows_per_shard - return min(rows_from_sampler, self.num_samples) + return min(rows_from_sampler, rows_per_shard) def is_shuffled(self): - return True + if self.shuffle_level is None: + return True + + return self.shuffle_level or self.sampler.is_shuffled() def is_sharded(self): - return False + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() + class Schema: diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index a1b9e908f3..8fd3a2bb9b 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -173,7 +173,9 @@ def traverse(node): # num_samples, shard_id, num_shards, shuffle # These arguments get moved into the sampler itself, so they are no longer needed to # be set at the dataset level. - if 'sampler' in node_args.keys(): + # TF Record is a special case because it uses both the dataset and sampler arguments + # which is not decided until later during tree preparation phase. + if node_repr['op_type'] != 'TFRecordDataset' and 'sampler' in node_args.keys(): if 'num_samples' in node_repr.keys(): node_repr['num_samples'] = None if 'shuffle' in node_repr.keys(): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 26ee62b811..98d66e9764 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -29,10 +29,11 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis from . import datasets from . import samplers +from . 
import cache_client def check_imagefolderdatasetv2(method): - """A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2).""" + """A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -58,7 +59,7 @@ def check_imagefolderdatasetv2(method): def check_mnist_cifar_dataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -81,7 +82,7 @@ def check_mnist_cifar_dataset(method): def check_manifestdataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -108,7 +109,7 @@ def check_manifestdataset(method): def check_tfrecorddataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -134,7 +135,7 @@ def check_tfrecorddataset(method): def check_vocdataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(VOCDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(VOCDataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -175,7 +176,7 @@ def check_vocdataset(method): def check_cocodataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(CocoDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(CocoDataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -211,7 +212,7 @@ def check_cocodataset(method): def check_celebadataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(CelebADataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(CelebADataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -247,7 +248,7 @@ def check_celebadataset(method): def check_minddataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(MindDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(MindDataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -279,7 +280,7 @@ def check_minddataset(method): def check_generatordataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -344,6 +345,27 @@ def check_generatordataset(method): return new_method +def check_random_dataset(method): + """A wrapper that wraps a parameter checker to the original Dataset(RandomDataset).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows'] + nreq_param_bool = ['shuffle'] + nreq_param_list = ['columns_list'] + + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + 
validate_dataset_param_value(nreq_param_list, param_dict, list) + + check_sampler_shuffle_shard_options(param_dict) + + return method(self, *args, **kwargs) + + return new_method + def check_pad_info(key, val): """check the key and value pair of pad_info in batch""" @@ -506,7 +528,7 @@ def check_map(method): @wraps(method) def new_method(self, *args, **kwargs): - [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing], _ = \ + [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \ parse_user_args(method, *args, **kwargs) nreq_param_columns = ['input_columns', 'output_columns'] @@ -516,6 +538,8 @@ def check_map(method): if num_parallel_workers is not None: check_num_parallel_workers(num_parallel_workers) type_check(python_multiprocessing, (bool,), "python_multiprocessing") + if cache is not None: + type_check(cache, (cache_client.DatasetCache,), "cache") for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]): if param is not None: @@ -720,7 +744,7 @@ def check_add_column(method): def check_cluedataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(CLUEDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -750,7 +774,7 @@ def check_cluedataset(method): def check_textfiledataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -823,7 +847,7 @@ def check_gnn_graphdata(method): def check_gnn_get_all_nodes(method): - """A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -836,7 +860,7 @@ def check_gnn_get_all_nodes(method): def check_gnn_get_all_edges(method): - """A wrapper that wrap a parameter checker to the GNN `get_all_edges` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_all_edges` function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -849,7 +873,7 @@ def check_gnn_get_all_edges(method): def check_gnn_get_nodes_from_edges(method): - """A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -862,7 +886,7 @@ def check_gnn_get_nodes_from_edges(method): def check_gnn_get_all_neighbors(method): - """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -877,7 +901,7 @@ def check_gnn_get_all_neighbors(method): def check_gnn_get_sampled_neighbors(method): - """A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -905,7 +929,7 @@ def check_gnn_get_sampled_neighbors(method): def check_gnn_get_neg_sampled_neighbors(method): - """A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` 
function.""" + """A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -921,7 +945,7 @@ def check_gnn_get_neg_sampled_neighbors(method): def check_gnn_random_walk(method): - """A wrapper that wrap a parameter checker to the GNN `random_walk` function.""" + """A wrapper that wraps a parameter checker to the GNN `random_walk` function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -968,7 +992,7 @@ def check_aligned_list(param, param_name, member_type): def check_gnn_get_node_feature(method): - """A wrapper that wrap a parameter checker to the GNN `get_node_feature` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_node_feature` function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -1012,7 +1036,7 @@ def check_gnn_get_edge_feature(method): def check_numpyslicesdataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset).""" @wraps(method) def new_method(self, *args, **kwargs): diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index 988d2f2118..a93d569810 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -39,7 +39,7 @@ def check_unique_list_of_words(words, arg_name): def check_lookup(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -56,7 +56,7 @@ def check_lookup(method): def check_from_file(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -74,7 +74,7 @@ def check_from_file(method): def check_from_list(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -97,7 +97,7 @@ def check_from_list(method): def check_from_dict(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -285,7 +285,7 @@ def check_bert_tokenizer(method): def check_from_dataset(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -328,7 +328,7 @@ def check_from_dataset(method): def check_ngram(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index 078845227d..4cb6613359 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -114,7 +114,7 @@ def check_erasing_value(value): def check_crop(method): - """A wrapper that wrap a parameter checker to the original function(crop operation).""" + """A wrapper that wraps a parameter 
checker to the original function(crop operation).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -127,7 +127,7 @@ def check_crop(method): def check_resize_interpolation(method): - """A wrapper that wrap a parameter checker to the original function(resize interpolation operation).""" + """A wrapper that wraps a parameter checker to the original function(resize interpolation operation).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -142,7 +142,7 @@ def check_resize_interpolation(method): def check_resize(method): - """A wrapper that wrap a parameter checker to the original function(resize operation).""" + """A wrapper that wraps a parameter checker to the original function(resize operation).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -155,7 +155,7 @@ def check_resize(method): def check_random_resize_crop(method): - """A wrapper that wrap a parameter checker to the original function(random resize crop operation).""" + """A wrapper that wraps a parameter checker to the original function(random resize crop operation).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -178,7 +178,7 @@ def check_random_resize_crop(method): def check_prob(method): - """A wrapper that wrap a parameter checker(check the probability) to the original function.""" + """A wrapper that wraps a parameter checker(check the probability) to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -192,7 +192,7 @@ def check_prob(method): def check_normalize_c(method): - """A wrapper that wrap a parameter checker to the original function(normalize operation written in C++).""" + """A wrapper that wraps a parameter checker to the original function(normalize operation written in C++).""" @wraps(method) def new_method(self, *args, **kwargs): @@ -205,7 +205,7 @@ def check_normalize_c(method): def check_normalize_py(method): - """A wrapper that wrap a parameter checker to the original function(normalize operation written in Python).""" + """A wrapper that wraps a parameter checker to the original function(normalize operation written in Python).""" @wraps(method) def new_method(self, *args, **kwargs): diff --git a/tests/ut/cpp/dataset/c_api_test.cc b/tests/ut/cpp/dataset/c_api_test.cc index 7a3b6d552b..385b327768 100644 --- a/tests/ut/cpp/dataset/c_api_test.cc +++ b/tests/ut/cpp/dataset/c_api_test.cc @@ -738,7 +738,7 @@ TEST_F(MindDataTestPipeline, TestProjectMap) { EXPECT_TRUE(ds != nullptr); // Create a Project operation on ds - std::vector column_project = {"label"}; + std::vector column_project = {"image"}; ds = ds->Project(column_project); EXPECT_TRUE(ds != nullptr); diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc new file mode 100644 index 0000000000..a31a8f8ddf --- /dev/null +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -0,0 +1,579 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "dataset/core/client.h" +#include "dataset/engine/cache/cache_client.h" +#include "dataset/engine/execution_tree.h" +#include "dataset/engine/datasetops/cache_op.h" +#include "dataset/engine/datasetops/cache_lookup_op.h" +#include "dataset/engine/datasetops/cache_merge_op.h" +#include "dataset/engine/datasetops/source/image_folder_op.h" +#include "common/common.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" +#include "dataset/util/storage_container.h" // lint !e322 +#include "dataset/engine/datasetops/source/random_data_op.h" +#include "dataset/engine/data_schema.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::dataset::CacheClient; +using mindspore::dataset::TaskGroup; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +class MindDataTestCacheOp : public UT::DatasetOpTesting { + public: + void SetUp() override { + DatasetOpTesting::SetUp(); + GlobalInit(); + } +}; + +TEST_F(MindDataTestCacheOp, TestCacheServer) { + Status rc; + CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true + // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated. + rc = myClient.CreateCache(1, true); + EXPECT_TRUE(rc.IsOk()); + std::cout << myClient << std::endl; + + // Create a schema using the C api's + int32_t rank = 0; // not used + std::unique_ptr testSchema = std::make_unique(); + // 2 columns. First column is an "image" 640,480,3 + TensorShape c1Shape({640, 480, 3}); + ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, + rank, // not used + &c1Shape); + // Column 2 will just be a scalar label number + TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor + ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); + + testSchema->AddColumn(c1); + testSchema->AddColumn(c2); + + std::unordered_map map; + rc = testSchema->GetColumnNameMap(&map); + EXPECT_TRUE(rc.IsOk()); + + // Test the CacheSchema api + rc = myClient.CacheSchema(map); + EXPECT_TRUE(rc.IsOk()); + + // Create a tensor, take a snapshot and restore it back, and compare. + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); + t->SetItemAt({0, 0}, 1); + t->SetItemAt({0, 1}, 2); + t->SetItemAt({0, 2}, 3); + t->SetItemAt({1, 0}, 4); + t->SetItemAt({1, 1}, 5); + t->SetItemAt({1, 2}, 6); + std::cout << *t << std::endl; + TensorTable tbl; + TensorRow row; + row.push_back(t); + int64_t row_id; + rc = myClient.WriteRow(row, &row_id); + EXPECT_TRUE(rc.IsOk()); + + // Switch off build phase. + rc = myClient.BuildPhaseDone(); + EXPECT_TRUE(rc.IsOk()); + + // Now restore from cache. 
+ row.clear(); + rc = myClient.GetRows({row_id}, &tbl); + row = tbl.front(); + EXPECT_TRUE(rc.IsOk()); + auto r = row.front(); + std::cout << *r << std::endl; + // Compare + bool cmp = (*t == *r); + EXPECT_TRUE(cmp); + + // Get back the schema and verify + std::unordered_map map_out; + rc = myClient.FetchSchema(&map_out); + EXPECT_TRUE(rc.IsOk()); + cmp = (map_out == map); + EXPECT_TRUE(cmp); + + // Test Purge and Destroy + rc = myClient.PurgeCache(); + EXPECT_TRUE(rc.IsOk()); + rc = myClient.DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { + // Clear the rc of the master thread if any + (void)TaskManager::GetMasterThreadRc(); + TaskGroup vg; + Status rc; + CacheClient myClient(1, 1, true); // use arbitrary session of 1, size 1, spilling is true + // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated. + rc = myClient.CreateCache(1, true); + EXPECT_TRUE(rc.IsOk()); + std::cout << myClient << std::endl; + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); + t->SetItemAt({0, 0}, 1); + t->SetItemAt({0, 1}, 2); + t->SetItemAt({0, 2}, 3); + t->SetItemAt({1, 0}, 4); + t->SetItemAt({1, 1}, 5); + t->SetItemAt({1, 2}, 6); + TensorTable tbl; + TensorRow row; + row.push_back(t); + // Cache tensor row t 5000 times using 10 threads. + for (auto k = 0; k < 10; ++k) { + Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status { + TaskManager::FindMe()->Post(); + for (auto i = 0; i < 500; i++) { + RETURN_IF_NOT_OK(myClient.WriteRow(row)); + } + return Status::OK(); + }); + EXPECT_TRUE(vg_rc.IsOk()); + } + ASSERT_TRUE(vg.join_all().IsOk()); + ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk()); + rc = myClient.BuildPhaseDone(); + ASSERT_TRUE(rc.IsOk()); + // Get statistics from the server. + CacheClient::ServiceStat stat{}; + rc = myClient.GetStat(&stat); + ASSERT_TRUE(rc.IsOk()); + std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached + << "\n"; + // Expect there are 5000 rows there. + EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1); + // Get them all back using row id and compare with tensor t. + for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) { + tbl.clear(); + row.clear(); + rc = myClient.GetRows({i}, &tbl); + EXPECT_TRUE(rc.IsOk()); + row = tbl.front(); + auto r = row.front(); + bool cmp = (*t == *r); + EXPECT_TRUE(cmp); + } + rc = myClient.DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +// Simple test with a repeated cache op over random data producer +// +// RepeatOp +// | +// CacheOp +// | +// RandomDataOp +// +TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { + Status rc; + int32_t rank = 0; // not used + MS_LOG(INFO) << "UT test TestRandomDataCache1"; + // Start with an empty execution tree + auto myTree = std::make_shared(); + + // Create a schema using the C api's + std::unique_ptr testSchema = std::make_unique(); + + // 2 columns. 
First column is an "image" 640,480,3 + TensorShape c1Shape({640, 480, 3}); + ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, + rank, // not used + &c1Shape); + + // Column 2 will just be a scalar label number + TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor + ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); + + testSchema->AddColumn(c1); + testSchema->AddColumn(c2); + + // RandomDataOp + std::shared_ptr myRandomDataOp; + rc = RandomDataOp::Builder() + .SetRowsPerBuffer(4) + .SetNumWorkers(4) + .SetDataSchema(std::move(testSchema)) + .SetTotalRows(50) // 50 samples for now + .Build(&myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + + // CacheOp + // size of 0, spilling is true + std::shared_ptr myClient = std::make_shared(1, 0, true); + std::shared_ptr myCacheOp; + + int64_t num_samples = 0; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + rc = CacheOp::Builder() + .SetNumWorkers(5) + .SetClient(myClient) + .SetRowsPerBuffer(4) + .SetSampler(std::move(seq_sampler)) + .Build(&myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + + // RepeatOp + uint32_t numRepeats = 4; + std::shared_ptr myRepeatOp; + rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + // Assign tree relations and root + rc = myRepeatOp->AddChild(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myCacheOp->AddChild(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssignRoot(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration"; + rc = myTree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + + // quick check to see what tree looks like + std::ostringstream ss; + ss << *myTree; // some funny const error if I try to write directly to ms log stream + MS_LOG(INFO) << "Here's the tree:\n" << ss.str(); + + std::cout << *myClient << std::endl; + + rc = myTree->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator dI(myTree); + TensorRow tensorList; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + int rowCount = 0; + while (!tensorList.empty()) { + // Don't display these rows, just count them + MS_LOG(INFO) << "Row fetched #: " << rowCount; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + rowCount++; + } + ASSERT_EQ(rowCount, 200); + rc = myClient->DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +//// Simple test with a repeated cache op over random data producer. +//// This one will exceed memory and require a spill. +//// +//// RepeatOp +//// | +//// CacheOp +//// | +//// RandomDataOp +//// +TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { + Status rc; + int32_t rank = 0; // not used + MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; + // Start with an empty execution tree + auto myTree = std::make_shared(); + + // Create a schema using the C api's + std::unique_ptr testSchema = std::make_unique(); + + // 2 columns. 
First column is an "image" 640,480,3 + TensorShape c1Shape({640, 480, 3}); + ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, + rank, // not used + &c1Shape); + + // Column 2 will just be a scalar label number + TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor + ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); + + testSchema->AddColumn(c1); + testSchema->AddColumn(c2); + + // RandomDataOp + std::shared_ptr myRandomDataOp; + rc = RandomDataOp::Builder() + .SetRowsPerBuffer(2) + .SetNumWorkers(4) + .SetDataSchema(std::move(testSchema)) + .SetTotalRows(10) + .Build(&myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + + // CacheOp + int64_t num_samples = 0; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + std::shared_ptr myClient = std::make_shared(1, 4, true); + std::shared_ptr myCacheOp; + rc = CacheOp::Builder() + .SetNumWorkers(4) + .SetClient(myClient) + .SetRowsPerBuffer(3) + .SetSampler(std::move(seq_sampler)) + .Build(&myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + + // RepeatOp + uint32_t numRepeats = 4; + std::shared_ptr myRepeatOp; + rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + // Assign tree relations and root + rc = myRepeatOp->AddChild(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myCacheOp->AddChild(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssignRoot(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration"; + rc = myTree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + + std::cout << *myClient << std::endl; + + rc = myTree->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator dI(myTree); + TensorRow tensorList; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + int rowCount = 0; + while (!tensorList.empty()) { + // Don't display these rows, just count them + MS_LOG(INFO) << "Row fetched #: " << rowCount; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + rowCount++; + } + ASSERT_EQ(rowCount, 40); + rc = myClient->DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { + Status rc; + int64_t num_samples = 0; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + + std::shared_ptr myClient = std::make_shared(1, 0, true); + + std::shared_ptr myMergeOp; + rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build( + &myMergeOp); + EXPECT_TRUE(rc.IsOk()); + + std::shared_ptr myLookupOp; + rc = CacheLookupOp::Builder() + .SetNumWorkers(3) + .SetOpConnectorSize(3) + .SetClient(myClient) + .SetSampler(seq_sampler) + .Build(&myLookupOp); + EXPECT_TRUE(rc.IsOk()); + + std::shared_ptr so; + ImageFolderOp::Builder builder; + builder.SetSampler(myLookupOp) + .SetOpConnectorSize(3) + .SetNumWorkers(3) + .SetRowsPerBuffer(2) + .SetExtensions({".jpg", ".JPEG"}) + .SetRecursive(true) + .SetImageFolderDir(datasets_root_path_ + "/testPK/data"); + rc = builder.Build(&so); + EXPECT_TRUE(rc.IsOk()); + + // RepeatOp + uint32_t numRepeats = 4; + std::shared_ptr myRepeatOp; + rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + 
+ auto myTree = std::make_shared(); + rc = myTree->AssociateNode(so); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myLookupOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myMergeOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssignRoot(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + rc = myRepeatOp->AddChild(myMergeOp); + EXPECT_TRUE(rc.IsOk()); + rc = myMergeOp->AddChild(myLookupOp); + EXPECT_TRUE(rc.IsOk()); + rc = myMergeOp->AddChild(so); + EXPECT_TRUE(rc.IsOk()); + + rc = myTree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->Launch(); + EXPECT_TRUE(rc.IsOk()); + // Start the loop of reading tensors from our pipeline + DatasetIterator dI(myTree); + TensorRow tensorList; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + int rowCount = 0; + while (!tensorList.empty()) { + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + if (rc.IsError()) { + std::cout << rc << std::endl; + break; + } + rowCount++; + } + ASSERT_EQ(rowCount, 176); + std::cout << "Row count : " << rowCount << std::endl; + rc = myClient->DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +//// Simple test with a repeated cache op over random data producer. +//// The difference in this one is that you do not add the sampler to the cache op directly. +//// Instead, the sampler is added as part of the leaf op construction. Then, the prepare +//// phase will pull this up from the leaf and into the cache. +//// It removes the sampler from the leaf op, which doesn't make sense there anyway for +//// the RandomDataOp which doesn't support sampling without a cache. +//// +//// RepeatOp +//// | +//// CacheOp +//// | +//// RandomDataOp +//// +TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) { + Status rc; + int32_t rank = 0; // not used + MS_LOG(INFO) << "UT test TestCacheInheritSampler"; + + int64_t num_samples = 0; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + + // Start with an empty execution tree + auto myTree = std::make_shared(); + + // Create a schema using the C api's + std::unique_ptr testSchema = std::make_unique(); + + // 2 columns. 
First column is an "image" 640,480,3 + TensorShape c1Shape({640, 480, 3}); + ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, + rank, // not used + &c1Shape); + + // Column 2 will just be a scalar label number + TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor + ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); + + testSchema->AddColumn(c1); + testSchema->AddColumn(c2); + + // RandomDataOp + std::shared_ptr myRandomDataOp; + rc = RandomDataOp::Builder() + .SetRowsPerBuffer(2) + .SetNumWorkers(4) + .SetDataSchema(std::move(testSchema)) + .SetTotalRows(10) + .SetSampler(std::move(seq_sampler)) + .Build(&myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + + // CacheOp + std::shared_ptr myClient = std::make_shared(1, 4, true); + std::shared_ptr myCacheOp; + rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + + // RepeatOp + uint32_t numRepeats = 4; + std::shared_ptr myRepeatOp; + rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + // Assign tree relations and root + rc = myRepeatOp->AddChild(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myCacheOp->AddChild(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssignRoot(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration"; + rc = myTree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + + std::cout << *myClient << std::endl; + + rc = myTree->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator dI(myTree); + TensorRow tensorList; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + int rowCount = 0; + while (!tensorList.empty()) { + // Don't display these rows, just count them + MS_LOG(INFO) << "Row fetched #: " << rowCount; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + rowCount++; + } + ASSERT_EQ(rowCount, 40); + rc = myClient->DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} diff --git a/tests/ut/data/dataset/golden/cache_map_01_result.npz b/tests/ut/data/dataset/golden/cache_map_01_result.npz new file mode 100644 index 0000000000..7cff9ded88 Binary files /dev/null and b/tests/ut/data/dataset/golden/cache_map_01_result.npz differ diff --git a/tests/ut/data/dataset/golden/cache_map_02_result.npz b/tests/ut/data/dataset/golden/cache_map_02_result.npz new file mode 100644 index 0000000000..7cff9ded88 Binary files /dev/null and b/tests/ut/data/dataset/golden/cache_map_02_result.npz differ diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py new file mode 100644 index 0000000000..0e42b422aa --- /dev/null +++ b/tests/ut/python/dataset/test_cache_map.py @@ -0,0 +1,157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing cache operator with mappable datasets +""" +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as c_vision +from mindspore import log as logger +from util import save_and_check_md5 + +DATA_DIR = "../data/dataset/testImageNetData/train/" + +GENERATE_GOLDEN = False + +def test_cache_map_basic1(): + """ + Test mappable leaf with cache op right over the leaf + + Repeat + | + Map(decode) + | + Cache + | + ImageFolder + """ + + logger.info("Test cache map basic 1") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + filename = "cache_map_01_result.npz" + save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN) + + logger.info("test_cache_map_basic1 Ended.\n") + + +def test_cache_map_basic2(): + """ + Test mappable leaf with the cache op later in the tree above the map(decode) + + Repeat + | + Cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map basic 2") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + filename = "cache_map_02_result.npz" + save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN) + + logger.info("test_cache_map_basic2 Ended.\n") + + +def test_cache_map_basic3(): + """ + Test a repeat under mappable cache + + Cache + | + Map(decode) + | + Repeat + | + ImageFolder + """ + + logger.info("Test cache basic 3") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.repeat(4) + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info('test_cache_basic3 Ended.\n') + + +def test_cache_map_failure1(): + """ + Test nested cache (failure) + + Repeat + | + Cache + | + Map(decode) + | + Cache + | + ImageFolder + + """ + logger.info("Test cache failure 1") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + try: + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Nested cache operations is not supported!" 
in str(e) + + assert num_iter == 0 + logger.info('test_cache_failure1 Ended.\n') + +if __name__ == '__main__': + test_cache_map_basic1() + test_cache_map_basic2() + test_cache_map_basic3() + test_cache_map_failure1() diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py new file mode 100644 index 0000000000..39e00c0621 --- /dev/null +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -0,0 +1,429 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing cache operator with non-mappable datasets +""" +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as c_vision +from mindspore import log as logger + +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + +GENERATE_GOLDEN = False + +def test_cache_nomap_basic1(): + """ + A random dataset (a non mappable dataset) with a cache over it just after the leaf + """ + + logger.info("Test cache nomap basic 1") + + schema = ds.Schema() + schema.add_column('image', de_type=mstype.uint8, + shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) + schema.add_column('label', de_type=mstype.uint8, shape=[1]) + + # create a cache. arbitrary session_id for now + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # User-created sampler here + ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for data in ds1.create_dict_iterator(): + logger.info("printing the label: {}".format(data["label"])) + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 40 + logger.info("test_cache_nomap_basic1 Ended.\n") + + +def test_cache_nomap_basic2(): + """ + A random dataset (a non mappable dataset) with a cache over it just after the leaf + """ + + logger.info("Test cache nomap basic 2") + + schema = ds.Schema() + schema.add_column('image', de_type=mstype.uint8, + shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) + schema.add_column('label', de_type=mstype.uint8, shape=[1]) + + # create a cache. arbitrary session_id for now + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # sampler arg not given directly, however any of these args will auto-generate an appropriate sampler: + # num_samples, shuffle, num_shards, shard_id + # In this case, the presence of num_samples chooses a sampler. 
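The comment above can be verified directly, since this patch stores the chosen sampler on the dataset object as self.sampler. A quick check along the lines of this test follows (which concrete sampler class comes back for the num_samples case depends on _select_sampler branches not shown in this hunk):

import mindspore.dataset as ds
import mindspore.common.dtype as mstype

schema = ds.Schema()
schema.add_column('label', de_type=mstype.uint8, shape=[1])

# num_samples alone is enough for the non-mappable leaf to carry a sampler
# that a cache above it can inherit; with no sampling args it carries none.
with_args = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20)
no_args = ds.RandomDataset(schema=schema, total_rows=20)
print(with_args.sampler)  # a sampler object, ready for a cache to adopt
print(no_args.sampler)    # None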
+ ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache) + ds1 = ds1.repeat(2) + + num_iter = 0 + for data in ds1.create_dict_iterator(): + logger.info("printing the label: {}".format(data["label"])) + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 40 + logger.info("test_cache_nomap_basic2 Ended.\n") + + +def test_cache_nomap_basic3(): + """ + A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf + + Repeat + | + Map(decode) + | + Cache + | + TFReader + """ + + logger.info("Test cache nomap basic 3") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_basic3 Ended.\n") + + +def test_cache_nomap_basic4(): + """ + A TF reader dataset (a non mappable dataset) with a map decode and cache after it + Since a global shuffle is used for the tf reader, it will inject a shuffle op over the tf. + But, if there's a cache later, that shuffle becomes invalid and should be removed. + + Repeat + | + Cache + | + Map(decode) + | + TFReader + """ + + logger.info("Test cache nomap basic 4") + + # This dataset has 3 records in it only + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache + # in the picture. This causes a shuffle-injection over the TF. For clarify, this test will + # explicitly give the global option, even though it's the default in python. + # But, when caching is added in the ascendent tree above TF, we do global shuffling + # through the sampler over the cache, not by the shuffle op. In that case, tree prepare + # will remove the shuffle op that got injected by the initial tree creation. + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL) + decode_op = c_vision.Decode() + + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_basic4 Ended.\n") + + +def test_cache_nomap_basic5(): + """ + A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf + Same as test 3, but this one does not have shuffle arg, causing tf to default to global + shuffle which attempts to inject a shuffle operator. However, since there is a cache + we do not need global shuffle, so the shuffle will not be built. 
+def test_cache_nomap_basic5():
+    """
+    A TF reader dataset (a non-mappable dataset) with a cache over it, just after the leaf
+    Same as basic test 3, but this one does not give the shuffle arg, causing tf to default
+    to a global shuffle, which attempts to inject a shuffle operator. However, since there is
+    a cache, we do not need the global shuffle, so the shuffle will not be built. It ends up
+    being identical to basic test 3; however, we arrive at the same tree through different
+    code paths (if there were no cache, then the shuffle WOULD be built).
+
+        Repeat
+           |
+        Map(decode)
+           |
+        Cache
+           |
+        TFReader
+    """
+
+    logger.info("Test cache nomap basic 5")
+
+    # This dataset has only 3 records in it
+    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
+    decode_op = c_vision.Decode()
+    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
+    ds1 = ds1.repeat(4)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator():
+        num_iter += 1
+
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == 12
+    logger.info("test_cache_nomap_basic5 Ended.\n")
+
+
+def test_cache_nomap_basic6():
+    """
+    A TF reader dataset (a non-mappable dataset) with a cache over it, just after the leaf
+    In this one, the tf dataset is given a sharding configuration; however, since a cache is
+    used, tree prepare should undo the sharding configuration and instead choose a distributed
+    sampler with the same shard config.
+
+        Repeat
+           |
+        Map(decode)
+           |
+        Cache
+           |
+        TFReader
+    """
+
+    logger.info("Test cache nomap basic 6")
+
+    # This dataset has only 3 records in it
+    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
+
+    # With only 3 records sharded into 3, we expect only 1 record returned for this shard
+    # However, the sharding will be done by the sampler, not by the tf record leaf node
+    # In this case, it is row-based sharding, not the file-based sharding that would happen
+    # if there were no cache.
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache)
+    decode_op = c_vision.Decode()
+    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
+    ds1 = ds1.repeat(4)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator():
+        num_iter += 1
+
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == 4
+    logger.info("test_cache_nomap_basic6 Ended.\n")
+
+
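
The row-based sharding that basic 6 expects from the distributed sampler can be written down directly. This is an illustrative helper, not the sampler's actual implementation:

```python
def shard_rows(total_rows, num_shards, shard_id):
    # Row-based sharding: every num_shards-th row belongs to this shard
    return list(range(shard_id, total_rows, num_shards))

# 3 records sharded 3 ways: shard 1 sees a single row per epoch, so
# repeat(4) yields the 4 rows asserted in test_cache_nomap_basic6
rows_per_epoch = shard_rows(3, 3, 1)
assert rows_per_epoch == [1]
assert len(rows_per_epoch) * 4 == 4
```
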
+def test_cache_nomap_basic7():
+    """
+    A TF reader dataset (a non-mappable dataset) that uses global shuffle and is cached,
+    followed by a map.
+    In this one, the tf dataset with global shuffle might want to inject a shuffle op over
+    top of the tf reader, but since a cache is given, it will choose not to.
+
+        Repeat
+           |
+        Map(decode)
+           |
+        Cache
+           |
+        TFReader
+    """
+
+    logger.info("Test cache nomap basic 7")
+
+    # This dataset has only 3 records in it
+    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
+
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
+    decode_op = c_vision.Decode()
+    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
+    ds1 = ds1.repeat(4)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator():
+        num_iter += 1
+
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == 12
+    logger.info("test_cache_nomap_basic7 Ended.\n")
+
+
+def test_cache_nomap_allowed_share1():
+    """
+    It is allowed to share the cache between the following two trees:
+
+        Repeat        Shuffle
+           |             |
+        Cache         Cache
+           |             |
+        TFReader      TFReader
+    """
+
+    logger.info("Test cache nomap allowed share 1")
+
+    ds.config.set_seed(1)
+    # This dataset has only 3 records in it
+    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
+    ds1 = ds1.repeat(4)
+
+    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
+    ds2 = ds2.shuffle(buffer_size=2)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator():
+        num_iter += 1
+    assert num_iter == 12
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+
+    num_iter = 0
+    for _ in ds2.create_dict_iterator():
+        num_iter += 1
+    assert num_iter == 3
+    logger.info("test_cache_nomap_allowed_share1 Ended.\n")
+
+
+def test_cache_nomap_allowed_share2():
+    """
+    It is allowed to share the cache between the following two trees (with map decode):
+
+        Repeat           Shuffle
+           |                |
+        Cache            Cache
+           |                |
+        Map(decode)      Map(decode)
+           |                |
+        TFReader         TFReader
+    """
+
+    logger.info("Test cache nomap allowed share 2")
+
+    ds.config.set_seed(1)
+    # This dataset has only 3 records in it
+    some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True)
+    decode_op = c_vision.Decode()
+
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
+    ds1 = ds1.repeat(4)
+
+    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    ds2 = ds2.map(input_columns=["image"], operations=decode_op, cache=some_cache)
+    ds2 = ds2.shuffle(buffer_size=2)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator():
+        num_iter += 1
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == 12
+
+    num_iter = 0
+    for _ in ds2.create_dict_iterator():
+        num_iter += 1
+    assert num_iter == 3
+    logger.info("test_cache_nomap_allowed_share2 Ended.\n")
+
+
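
Before the share tests continue, it may help to see why identical trees below a cache may share it while differing trees may not (test_cache_nomap_disallowed_share1, further down, asserts the rejection). This is a guess at the mechanism: a made-up fingerprint stands in for whatever check the cache client really performs on the tree beneath the cache.

```python
def tree_fingerprint(ops_below_cache):
    # Stand-in for however the cache identifies the pipeline beneath it
    return hash(tuple(ops_below_cache))

sessions = {}  # session_id -> fingerprint of the first pipeline to use it

def attach_cache(session_id, ops_below_cache):
    fp = tree_fingerprint(ops_below_cache)
    if sessions.setdefault(session_id, fp) != fp:
        raise RuntimeError("Attempt to re-use a cache for a different tree!")

attach_cache(1, ("tfreader", "map(decode)"))
attach_cache(1, ("tfreader", "map(decode)"))       # identical tree below: sharing is fine
try:
    attach_cache(1, ("tfreader", "map(rescale)"))  # different tree below: rejected
except RuntimeError as err:
    assert "different tree" in str(err)
```
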
+def test_cache_nomap_allowed_share3():
+    """
+    It is allowed to share the cache between the following two trees (different shard ids):
+
+        Repeat                    Repeat
+           |                         |
+        Cache                     Cache
+           |                         |
+        TFReader(shard_id = 0)    TFReader(shard_id = 1)
+    """
+
+    logger.info("Test cache nomap allowed share 3")
+
+    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
+
+    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
+    ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
+    ds1 = ds1.repeat(4)
+
+    ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache)
+    ds2 = ds2.repeat(4)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator():
+        num_iter += 1
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == 12
+
+    num_iter = 0
+    for _ in ds2.create_dict_iterator():
+        num_iter += 1
+    assert num_iter == 12
+    logger.info("test_cache_nomap_allowed_share3 Ended.\n")
+
+
+def test_cache_nomap_disallowed_share1():
+    """
+    It is not allowed to share the cache between the following two trees:
+
+        Cache           Cache
+           |               |
+        Map(decode)     Map(rescale)
+           |               |
+        TFReader        TFReader
+    """
+
+    logger.info("Test cache nomap disallowed share 1")
+
+    # This dataset has only 3 records in it
+    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
+    decode_op = c_vision.Decode()
+    rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)
+
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
+
+    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    ds2 = ds2.map(input_columns=["image"], operations=rescale_op, cache=some_cache)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator():
+        num_iter += 1
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == 3
+
+    try:
+        sum([1 for _ in ds2])
+    except RuntimeError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Attempt to re-use a cache for a different tree!" in str(e)
+
+    logger.info("test_cache_nomap_disallowed_share1 Ended.\n")
+
+
+if __name__ == '__main__':
+    test_cache_nomap_basic1()
+    test_cache_nomap_basic2()
+    test_cache_nomap_basic3()
+    test_cache_nomap_basic4()
+    test_cache_nomap_basic5()
+    test_cache_nomap_basic6()
+    test_cache_nomap_basic7()
+    test_cache_nomap_allowed_share1()
+    test_cache_nomap_allowed_share2()
+    test_cache_nomap_allowed_share3()
+    test_cache_nomap_disallowed_share1()
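
One small aside on the failure check in test_cache_nomap_disallowed_share1: if no exception is raised at all, the try/except lets the test pass silently. A tighter phrasing, sketched here with a stand-in for the failing iteration (pytest.raises fails the test when the expected error never occurs):

```python
import pytest

def iterate_bad_pipeline():
    # Stand-in for pulling rows through ds2 above; the real pipeline raises
    # at iteration time when the shared cache does not match its tree.
    raise RuntimeError("Attempt to re-use a cache for a different tree!")

def test_cache_reuse_rejected():
    with pytest.raises(RuntimeError, match="re-use a cache for a different tree"):
        iterate_bad_pipeline()
```
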
diff --git a/tests/ut/python/dataset/test_random_dataset.py b/tests/ut/python/dataset/test_random_dataset.py
index 4d50be254c..56a2a93113 100644
--- a/tests/ut/python/dataset/test_random_dataset.py
+++ b/tests/ut/python/dataset/test_random_dataset.py
@@ -16,17 +16,16 @@
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 from mindspore import log as logger
-
 # just a basic test with parallel random data op
 def test_randomdataset_basic1():
-    logger.info("Test randomdataset basic")
+    logger.info("Test randomdataset basic 1")
 
     schema = ds.Schema()
     schema.add_column('image', de_type=mstype.uint8, shape=[2])
     schema.add_column('label', de_type=mstype.uint8, shape=[1])
 
     # apply dataset operations
-    ds1 = ds.RandomDataset(schema=schema, num_samples=50, num_parallel_workers=4)
+    ds1 = ds.RandomDataset(schema=schema, total_rows=50, num_parallel_workers=4)
     ds1 = ds1.repeat(4)
 
     num_iter = 0
@@ -36,8 +35,9 @@ def test_randomdataset_basic1():
         logger.info("{} label: {}".format(num_iter, data["label"]))
         num_iter += 1
 
-    logger.info("Number of data in ds1: ", num_iter)
+    logger.info("Number of data in ds1: {}".format(num_iter))
     assert num_iter == 200
+    logger.info("Test randomdataset basic 1 complete")
 
 
 # Another simple test
@@ -49,10 +49,8 @@ def test_randomdataset_basic2():
                       shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
     schema.add_column('label', de_type=mstype.uint8, shape=[1])
 
-    # Make up about 10 samples
-    ds1 = ds.RandomDataset(schema=schema, num_samples=10, num_parallel_workers=1)
-
-    # cache size allows for about 4 images since each image just a bit less than 1MB, after that we will have to spill
+    # Make up 10 rows
+    ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=1)
     ds1 = ds1.repeat(4)
 
     num_iter = 0
@@ -62,11 +60,31 @@ def test_randomdataset_basic2():
         logger.info("printing the label: {}".format(data["label"]))
         num_iter += 1
 
-    logger.info("Number of data in ds1: ", num_iter)
+    logger.info("Number of data in ds1: {}".format(num_iter))
     assert num_iter == 40
+    logger.info("Test randomdataset basic 2 complete")
+
+
+# Another simple test
+def test_randomdataset_basic3():
+    logger.info("Test randomdataset basic 3")
+
+    # Make up 10 rows, but here even the schema is randomly created
+    # The columns are named "c0", "c1", "c2", etc.
+    # We use a tuple iterator instead of a dict iterator, so the column names
+    # are not needed to iterate
+    ds1 = ds.RandomDataset(total_rows=10, num_parallel_workers=1)
+    ds1 = ds1.repeat(2)
+
+    num_iter = 0
+    for _ in ds1.create_tuple_iterator():
+        num_iter += 1
+    logger.info("Number of data in ds1: {}".format(num_iter))
+    assert num_iter == 20
+    logger.info("Test randomdataset basic 3 complete")
 
 
 if __name__ == '__main__':
     test_randomdataset_basic1()
     test_randomdataset_basic2()
-    logger.info('test_randomdataset_basic Ended.\n')
+    test_randomdataset_basic3()
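
The rename from num_samples to total_rows in these tests separates two ideas: total_rows sizes the generated dataset, while num_samples (as used in test_cache_nomap_basic2) caps how many rows the auto-generated sampler draws. A tiny model of that relationship, with semantics inferred from the asserted row counts rather than from the implementation:

```python
def rows_per_epoch(total_rows, num_samples=None):
    # num_samples, when given, caps how many of the generated rows are drawn
    return total_rows if num_samples is None else min(num_samples, total_rows)

assert rows_per_epoch(50) * 4 == 200                 # test_randomdataset_basic1
assert rows_per_epoch(10) * 4 == 40                  # test_randomdataset_basic2
assert rows_per_epoch(20, num_samples=20) * 2 == 40  # test_cache_nomap_basic2
```
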