From d7b5f483459e66b084b9487c87b2e3bdc448a2bd Mon Sep 17 00:00:00 2001 From: hesham Date: Sat, 17 Oct 2020 02:34:31 -0400 Subject: [PATCH] Change c++ Iterator to use Runtime2 Change c++ Iterator to use Runtime Add Runtime header file Add treeConsumer header file --- .../ccsrc/minddata/dataset/api/iterator.cc | 98 ++--------- .../minddata/dataset/engine/CMakeLists.txt | 14 +- .../engine/consumers/python_tree_consumer.h | 49 ++++++ .../dataset/engine/consumers/tree_consumer.cc | 72 ++++++++ .../dataset/engine/consumers/tree_consumer.h | 154 ++++++++++++++++++ .../dataset/engine/runtime_context.cc | 25 +++ .../minddata/dataset/engine/runtime_context.h | 54 ++++++ .../minddata/dataset/engine/tree_adapter.cc | 2 - .../minddata/dataset/engine/tree_adapter.h | 3 + .../ccsrc/minddata/dataset/include/datasets.h | 2 + .../ccsrc/minddata/dataset/include/iterator.h | 17 +- .../ccsrc/minddata/dataset/util/status.h | 1 + 12 files changed, 389 insertions(+), 102 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/runtime_context.h diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc index cfdfcb8bc8..75991adfe2 100644 --- a/mindspore/ccsrc/minddata/dataset/api/iterator.cc +++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc @@ -15,6 +15,7 @@ */ #include "minddata/dataset/include/iterator.h" #include "minddata/dataset/core/client.h" +#include "minddata/dataset/engine/consumers/tree_consumer.h" #include "minddata/dataset/include/datasets.h" namespace mindspore { @@ -23,7 +24,7 @@ namespace api { // Get the next row from the data pipeline. bool Iterator::GetNextRow(TensorMap *row) { - Status rc = iterator_->GetNextAsMap(row); + Status rc = consumer_->GetNextAsMap(row); if (rc.IsError()) { MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; row->clear(); @@ -34,100 +35,27 @@ bool Iterator::GetNextRow(TensorMap *row) { // Get the next row from the data pipeline. bool Iterator::GetNextRow(TensorVec *row) { - TensorRow tensor_row; - Status rc = iterator_->FetchNextTensorRow(&tensor_row); + Status rc = consumer_->GetNextAsVector(row); if (rc.IsError()) { MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; row->clear(); return false; } - // Generate a vector as return - row->clear(); - std::copy(tensor_row.begin(), tensor_row.end(), std::back_inserter(*row)); return true; } // Shut down the data pipeline. -void Iterator::Stop() { - // Releasing the iterator_ unique_ptre. This should trigger the destructor of iterator_. - iterator_.reset(); - - // Release ownership of tree_ shared pointer. This will decrement the ref count. - tree_.reset(); -} - -// Function to build and launch the execution tree. +void Iterator::Stop() { runtime_context->Terminate(); } +// +//// Function to build and launch the execution tree. Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { - // One time init - Status rc; - rc = GlobalInit(); - RETURN_IF_NOT_OK(rc); - - // Instantiate the execution tree - tree_ = std::make_shared(); - - // Iterative BFS converting Dataset tree into runtime Execution tree. - std::queue, std::shared_ptr>> q; - - if (ds == nullptr) { - RETURN_STATUS_UNEXPECTED("Input is null pointer"); - } else { - // Convert the current root node. - auto root_ops = ds->Build(); - if (root_ops.empty()) { - RETURN_STATUS_UNEXPECTED("Node operation returned nothing"); - } - - // Iterate through all the DatasetOps returned by Dataset's Build(), associate them - // with the execution tree and add the child and parent relationship between the nodes - // Note that some Dataset objects might return more than one DatasetOps - // e.g. MapDataset will return [ProjectOp, MapOp] if project_columns is set for MapDataset - std::shared_ptr prev_op = nullptr; - for (auto op : root_ops) { - RETURN_IF_NOT_OK(tree_->AssociateNode(op)); - if (prev_op != nullptr) { - RETURN_IF_NOT_OK(prev_op->AddChild(op)); - } - prev_op = op; - } - // Add the last DatasetOp to the queue to be BFS. - q.push(std::make_pair(ds, root_ops.back())); - - // Traverse down to the children and convert them to the corresponding DatasetOps (i.e. execution tree nodes) - while (!q.empty()) { - auto node_pair = q.front(); - q.pop(); - // Iterate through all the direct children of the first element in our BFS queue - for (auto child : node_pair.first->children) { - auto child_ops = child->Build(); - if (child_ops.empty()) { - RETURN_STATUS_UNEXPECTED("Node operation returned nothing"); - } - auto node_op = node_pair.second; - // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them - // with the execution tree and add the child and parent relationship between the nodes - // Note that some Dataset objects might return more than one DatasetOps - // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset - for (auto child_op : child_ops) { - RETURN_IF_NOT_OK(tree_->AssociateNode(child_op)); - RETURN_IF_NOT_OK(node_op->AddChild(child_op)); - node_op = child_op; - } - // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current - // execution tree) to the BFS queue - q.push(std::make_pair(child, child_ops.back())); - } - } - RETURN_IF_NOT_OK(tree_->AssignRoot(root_ops.front())); - } - - // Launch the execution tree. - RETURN_IF_NOT_OK(tree_->Prepare()); - tree_->Launch(); - iterator_ = std::make_unique(tree_); - RETURN_UNEXPECTED_IF_NULL(iterator_); - - return rc; + runtime_context = std::make_unique(); + RETURN_IF_NOT_OK(runtime_context->Init()); + auto consumer = std::make_unique(); + consumer_ = consumer.get(); + RETURN_IF_NOT_OK(consumer->Init(ds)); + runtime_context->AssignConsumer(std::move(consumer)); + return Status::OK(); } } // namespace api diff --git a/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt index 688a802204..3c286fcab8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt @@ -10,12 +10,14 @@ endif () file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(engine OBJECT - execution_tree.cc - data_buffer.cc - data_schema.cc - dataset_iterator.cc - tree_adapter.cc - ) + execution_tree.cc + data_buffer.cc + data_schema.cc + dataset_iterator.cc + tree_adapter.cc + runtime_context.cc + consumers/tree_consumer.cc + ) if (ENABLE_PYTHON) target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h new file mode 100644 index 0000000000..5775dc1239 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/consumers/tree_consumer.h" + +namespace mindspore::dataset { + +/// Consumer that iterates over the dataset and returns the rows one by one as a python list or a dict +class PythonIterator : public IteratorConsumer { + /// Constructor + /// \param num_epochs number of epochs. Default to -1 (infinite epochs). + explicit PythonIterator(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {} + + /// Get the next row as a python dict + /// \param[out] output python dict + /// \return Status error code + Status GetNextAsMap(py::dict *output) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + /// Get the next row as a python dict + /// \param[out] output python dict + /// \return Status error code + Status GetNextAsList(py::list *output) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } +}; + +} // namespace mindspore::dataset +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc new file mode 100644 index 0000000000..814953b15e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/consumers/tree_consumer.h" +#include "minddata/dataset/engine/tree_adapter.h" + +namespace mindspore::dataset { + +Status IteratorConsumer::GetNextAsVector(std::vector *out) { + RETURN_UNEXPECTED_IF_NULL(out); + out->clear(); + + TensorRow res; + RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res)); + + // Return empty vector if there's no data + RETURN_OK_IF_TRUE(res.empty()); + + std::copy(res.begin(), res.end(), std::back_inserter(*out)); + return Status::OK(); +} +Status IteratorConsumer::GetNextAsMap(std::unordered_map *out_map) { + RETURN_UNEXPECTED_IF_NULL(out_map); + out_map->clear(); + + TensorRow res; + RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res)); + + // Return empty map if there's no data + RETURN_OK_IF_TRUE(res.empty()); + + // Populate the out map from the row and return it + for (const auto &colMap : tree_adapter_->GetColumnNameMap()) { + (*out_map)[colMap.first] = std::move(res[colMap.second]); + } + return Status::OK(); +} + +TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique(); } + +Status IteratorConsumer::Init(std::shared_ptr d) { + return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); +} +Status TreeConsumer::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } + +Status ToDevice::Init(std::shared_ptr d) { + // TODO(CRC): + // Get device ID from children look at get_distribution in python + // Add DeviceQue IR on top of dataset d + + return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); +} +} // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h new file mode 100644 index 0000000000..00cf7e184f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/tree_adapter.h" + +namespace mindspore::dataset { +// Forward declare +class TreeAdapter; + +namespace api { +class Dataset; +} + +/// A base class for tree consumers which would fetch rows from the tree pipeline +class TreeConsumer { + public: + /// Constructor that prepares an empty tree_adapter + TreeConsumer(); + /// Initializes the consumer, this involves constructing and preparing the tree. + /// \param d The dataset node that represent the root of the IR tree. + /// \return Status error code. + virtual Status Init(std::shared_ptr d); + + protected: + /// The class owns the tree_adapter that handles execution tree operations. + std::unique_ptr tree_adapter_; + /// Method to return the name of the consumer + /// \return string + virtual std::string Name() = 0; +}; + +/// Consumer that iterates over the dataset and returns the rows one by one as a vector or a map +class IteratorConsumer : public TreeConsumer { + public: + /// Constructor which will call the base class default constructor. + /// \param num_epochs number of epochs. Default to -1 (infinite epochs). + explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} + + Status Init(std::shared_ptr d) override; + + /// Returns the next row in a vector format + /// \param[out] out std::vector of Tensors + /// \return Status error code + Status GetNextAsVector(std::vector *out); + + /// Returns the next row in as a map + /// \param[out] out std::map of string to Tensor + /// \return Status error code + Status GetNextAsMap(std::unordered_map *out); + + protected: + /// Method to return the name of the consumer + /// \return string + std::string Name() override { return "IteratorConsumer"; } + + private: + int32_t num_epochs_; +}; + +/// Consumer that iterates over the dataset and writes it to desk +class SaveToDesk : public TreeConsumer { + public: + /// Constructor which will call the base class default constructor. + /// \param dataset_path path the the dataset + /// \param num_files number of files. Default to 1 + /// \param dataset_type The format of the dataset. Default to "mindrecod". + explicit SaveToDesk(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") + : TreeConsumer(), dataset_path_(dataset_path), num_files_(num_files), dataset_type_(dataset_type) {} + + /// Save the given dataset to MindRecord format on desk. This is a blocking method (i.e., after returning, all rows + /// would be written to desk) + /// \return Status error code + Status Save() { return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); } + + private: + std::string dataset_path_; + int32_t num_files_; + std::string dataset_type_; +}; + +/// Consumer that iterates over the dataset and send it to a device +class ToDevice : public TreeConsumer { + public: + ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs) + : TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} + + Status Init(std::shared_ptr d) override; + + Status Send() { + // TODO(CRC): launch the tree + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + Status Stop() { + // TODO(CRC): Get root + call StopSend + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + Status Continue() { + // TODO(CRC): Get root + call StopSend + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + + private: + std::string device_type_; + bool send_epoch_end_; + int32_t num_epochs_; +}; + +/// Consumer that is used to get some pipeline information +class TreeGetters : public TreeConsumer { + Status GetDatasetSize(int32_t *size) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + Status GetBatchSize(int32_t *batch_size) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + Status GetRepeatCount(int32_t *repeat_count) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + Status GetNumClasses(int32_t *num_classes) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + Status GetOutputShapes(std::vector *shapes) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + Status GetOutputTypes(std::vector *types) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } + Status GetOutputNames(std::vector *names) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); + } +}; + +} // namespace mindspore::dataset +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc new file mode 100644 index 0000000000..e82d6a1bd9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/runtime_context.h" +#include +#include +namespace mindspore::dataset { + +void RuntimeContext::AssignConsumer(std::unique_ptr tree_consumer) { + tree_consumer_ = std::move(tree_consumer); +} +} // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/runtime_context.h b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.h new file mode 100644 index 0000000000..855acc3fc0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ + +#include +#include +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/engine/consumers/tree_consumer.h" + +namespace mindspore::dataset { +class TreeConsumer; + +/// Class the represents single runtime instance which can consume data from a data pipeline +class RuntimeContext { + public: + /// Default constructor + RuntimeContext() = default; + + /// Initialize the runtime, for now we just call the global init + /// \return Status error code + Status Init() { return GlobalInit(); } + + /// Method to terminate the runtime, this will not release the resources + /// \return Status error code + virtual Status Terminate() { return Status::OK(); } + + /// Set the tree consumer + /// \param tree_consumer to be assigned + void AssignConsumer(std::unique_ptr tree_consumer); + + /// Get the tree consumer + /// \return Raw pointer to the tree consumer. + TreeConsumer *GetConsumer() { return tree_consumer_.get(); } + + private: + std::unique_ptr tree_consumer_; +}; + +} // namespace mindspore::dataset +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index 993ce85914..11fe226784 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -26,8 +26,6 @@ Status TreeAdapter::BuildAndPrepare(std::shared_ptr root_ir, int32 // Check whether this function has been called before. If so, return fail CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built."); RETURN_UNEXPECTED_IF_NULL(root_ir); - // GlobalInit, might need to be moved to the proper place once RuntimeConext is complete - RETURN_IF_NOT_OK(GlobalInit()); // this will evolve in the long run tree_ = std::make_unique(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h index 38907e82e7..3afbe56357 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h @@ -28,6 +28,9 @@ namespace mindspore { namespace dataset { +namespace api { +class Dataset; +} class TreeAdapter { public: TreeAdapter() = default; diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 527f050c4f..beada4515a 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -51,6 +51,8 @@ class Vocab; #endif namespace api { +class Dataset; +class Iterator; class TensorOperation; class SchemaObj; diff --git a/mindspore/ccsrc/minddata/dataset/include/iterator.h b/mindspore/ccsrc/minddata/dataset/include/iterator.h index b6f476c54b..4c82fad574 100644 --- a/mindspore/ccsrc/minddata/dataset/include/iterator.h +++ b/mindspore/ccsrc/minddata/dataset/include/iterator.h @@ -17,10 +17,11 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ -#include #include -#include #include +#include +#include +#include "minddata/dataset/engine/runtime_context.h" #include "minddata/dataset/include/status.h" namespace mindspore { @@ -32,6 +33,8 @@ class DatasetIterator; class DatasetOp; class Tensor; +class RuntimeContext; +class IteratorConsumer; namespace api { class Dataset; @@ -43,7 +46,7 @@ using TensorVec = std::vector>; class Iterator { public: /// \brief Constructor - Iterator() = default; + Iterator() : consumer_(nullptr) {} /// \brief Destructor ~Iterator() = default; @@ -111,12 +114,8 @@ class Iterator { _Iterator end() { return _Iterator(nullptr); } private: - // Runtime tree. - // Use shared_ptr instead of unique_ptr because the DatasetIterator constructor takes in a shared_ptr type. - std::shared_ptr tree_; - - // Runtime iterator - std::unique_ptr iterator_; + std::unique_ptr runtime_context; + IteratorConsumer *consumer_; }; } // namespace api } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/util/status.h index bc63de9870..b88f69bbe8 100644 --- a/mindspore/ccsrc/minddata/dataset/util/status.h +++ b/mindspore/ccsrc/minddata/dataset/util/status.h @@ -89,6 +89,7 @@ enum class StatusCode : char { kTimeOut = 14, kBuddySpaceFull = 15, kNetWorkError = 16, + kNotImplementedYet = 17, // Make this error code the last one. Add new error code above it. kUnexpectedError = 127 };