diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 785fbb58f8..639e47e1ad 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -111,7 +111,7 @@ bool Dataset::DeviceQueue(bool send_epoch_end) { Status rc; // Build and launch tree - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc; @@ -147,7 +147,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data Status rc; // Build and launch tree auto ds = shared_from_this(); - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "CreateSaver failed." << rc; @@ -193,7 +193,7 @@ Dataset::Dataset() { tree_getters_ = std::make_shared(); } int64_t Dataset::GetDatasetSize() { int64_t dataset_size; Status rc; - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; @@ -213,7 +213,7 @@ int64_t Dataset::GetDatasetSize() { std::vector Dataset::GetOutputTypes() { std::vector types; Status rc; - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed."; @@ -240,7 +240,7 @@ std::vector Dataset::GetOutputTypes() { std::vector Dataset::GetOutputShapes() { std::vector shapes; Status rc; - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed."; @@ -268,7 +268,7 @@ int64_t Dataset::GetNumClasses() { int64_t num_classes; auto ds = shared_from_this(); Status rc; - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed."; @@ -562,7 +562,7 @@ int64_t Dataset::GetBatchSize() { int64_t batch_size; auto ds = shared_from_this(); Status rc; - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed."; @@ -583,7 +583,7 @@ int64_t Dataset::GetRepeatCount() { int64_t repeat_count; auto ds = shared_from_this(); Status rc; - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed."; @@ -613,7 +613,7 @@ std::shared_ptr Dataset::BuildSentencePieceVocab( auto ds = std::make_shared(IRNode(), vocab, col_names, vocab_size, character_coverage, model_type, params); - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); Status rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc; @@ -645,7 +645,7 @@ std::shared_ptr Dataset::BuildVocab(const std::vector &colum auto ds = std::make_shared(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first); - std::unique_ptr runtime_context = std::make_unique(); + std::unique_ptr runtime_context = std::make_unique(); Status rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc; diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc index 5410ba4052..31a9ffaad5 100644 --- a/mindspore/ccsrc/minddata/dataset/api/iterator.cc +++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc @@ -48,7 +48,7 @@ void Iterator::Stop() { runtime_context_->Terminate(); } // Function to build and launch the execution tree. Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { - runtime_context_ = std::make_unique(); + runtime_context_ = std::make_unique(); RETURN_IF_NOT_OK(runtime_context_->Init()); auto consumer = std::make_unique(); consumer_ = consumer.get(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc b/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc index ae9ab4c275..c0c14bb415 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc @@ -19,9 +19,24 @@ namespace mindspore::dataset { -Status PythonRuntimeContext::Terminate() { +Status PythonRuntimeContext::Terminate() { return TerminateImpl(); } + +Status PythonRuntimeContext::TerminateImpl() { + CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); // Release GIL before joining all threads py::gil_scoped_release gil_release; return tree_consumer_->Terminate(); } + +PythonRuntimeContext::~PythonRuntimeContext() { + TerminateImpl(); + { + py::gil_scoped_acquire gil_acquire; + tree_consumer_.reset(); + } +} + +PythonIteratorConsumer *PythonRuntimeContext::GetPythonConsumer() { + return dynamic_cast(tree_consumer_.get()); +} } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h b/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h index 05116fed63..c07d741025 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h +++ b/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h @@ -24,25 +24,24 @@ #include "minddata/dataset/engine/runtime_context.h" namespace mindspore::dataset { -class RuntimeContext; +class NativeRuntimeContext; -/// Class that represents single runtime instance which can consume data from a data pipeline +/// Class that represents Python single runtime instance which can consume data from a data pipeline class PythonRuntimeContext : public RuntimeContext { public: /// Method to terminate the runtime, this will not release the resources /// \return Status error code Status Terminate() override; - // Safe destructing the tree that includes python objects - ~PythonRuntimeContext() { - Terminate(); - { - py::gil_scoped_acquire gil_acquire; - tree_consumer_.reset(); - } - } + /// Safe destructing the tree that includes python objects + ~PythonRuntimeContext() override; - PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast(tree_consumer_.get()); } + PythonIteratorConsumer *GetPythonConsumer(); + + private: + /// Internal function to perform the termination + /// \return Status error code + Status TerminateImpl(); }; } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc index 1d15495c7b..12582195ae 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc @@ -22,4 +22,17 @@ namespace mindspore::dataset { void RuntimeContext::AssignConsumer(std::shared_ptr tree_consumer) { tree_consumer_ = std::move(tree_consumer); } +Status NativeRuntimeContext::Terminate() { return TerminateImpl(); } + +Status NativeRuntimeContext::TerminateImpl() { + CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); + return tree_consumer_->Terminate(); +} + +NativeRuntimeContext::~NativeRuntimeContext() { TerminateImpl(); } + +TreeConsumer *RuntimeContext::GetConsumer() { return tree_consumer_.get(); } + +Status RuntimeContext::Init() { return GlobalInit(); } + } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/runtime_context.h b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.h index a2f3e17b47..bea1725594 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/runtime_context.h +++ b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.h @@ -23,8 +23,7 @@ namespace mindspore::dataset { class TreeConsumer; - -/// Class the represents single runtime instance which can consume data from a data pipeline +/// Class that represents single runtime instance which can consume data from a data pipeline class RuntimeContext { public: /// Default constructor @@ -32,11 +31,7 @@ class RuntimeContext { /// Initialize the runtime, for now we just call the global init /// \return Status error code - Status Init() { return GlobalInit(); } - - /// Method to terminate the runtime, this will not release the resources - /// \return Status error code - virtual Status Terminate() { return Status::OK(); } + Status Init(); /// Set the tree consumer /// \param tree_consumer to be assigned @@ -44,13 +39,32 @@ class RuntimeContext { /// Get the tree consumer /// \return Raw pointer to the tree consumer. - TreeConsumer *GetConsumer() { return tree_consumer_.get(); } + TreeConsumer *GetConsumer(); - ~RuntimeContext() { Terminate(); } + /// Method to terminate the runtime, this will not release the resources + /// \return Status error code + virtual Status Terminate() = 0; + + virtual ~RuntimeContext() = default; protected: std::shared_ptr tree_consumer_; }; +/// Class that represents C++ single runtime instance which can consume data from a data pipeline +class NativeRuntimeContext : public RuntimeContext { + public: + /// Method to terminate the runtime, this will not release the resources + /// \return Status error code + Status Terminate() override; + + ~NativeRuntimeContext() override; + + private: + /// Internal function to perform the termination + /// \return Status error code + Status TerminateImpl(); +}; + } // namespace mindspore::dataset #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/iterator.h b/mindspore/ccsrc/minddata/dataset/include/iterator.h index 959a9c5191..1686f1552d 100644 --- a/mindspore/ccsrc/minddata/dataset/include/iterator.h +++ b/mindspore/ccsrc/minddata/dataset/include/iterator.h @@ -33,7 +33,7 @@ class DatasetIterator; class DatasetOp; class Tensor; -class RuntimeContext; +class NativeRuntimeContext; class IteratorConsumer; class Dataset; @@ -113,7 +113,7 @@ class Iterator { _Iterator end() { return _Iterator(nullptr); } private: - std::unique_ptr runtime_context_; + std::unique_ptr runtime_context_; IteratorConsumer *consumer_; }; } // namespace dataset