From bcf913fb0d1386b33d8df3d056d4e829723865ae Mon Sep 17 00:00:00 2001 From: hesham Date: Fri, 9 Oct 2020 00:25:45 -0400 Subject: [PATCH] Add temp fix gor generator Op when num_epochs=-1 --- .../engine/datasetops/epoch_ctrl_op.cc | 17 +++- .../dataset/engine/datasetops/repeat_op.cc | 31 ++++++- .../dataset/engine/datasetops/repeat_op.h | 10 ++ .../engine/datasetops/source/generator_op.cc | 18 +++- .../engine/datasetops/source/generator_op.h | 2 + .../dataset/engine/opt/post/repeat_pass.cc | 91 ++++++++++++++++++- .../dataset/engine/opt/post/repeat_pass.h | 19 ++++ 7 files changed, 182 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc index 5ec26311bf..1343fd4608 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc @@ -51,7 +51,15 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common detailed info PipelineOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_; + out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_ + << "\nLeaf Nodes in execution path:"; + if (!eoe_ops_.empty()) { + for (size_t i = 0; i < eoe_ops_.size(); i++) { + out << "\n Operator: " << eoe_ops_[i]->id(); + } + } else { + out << " None."; + } out << "\n\n"; } } @@ -86,6 +94,13 @@ Status EpochCtrlOp::EoeReceived(int32_t worker_id) { // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. state_ = OpState::kDeOpIdle; + if (repeat_count_ != num_repeats_) { + for (auto &eoe_op : eoe_ops_) { + MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id(); + RETURN_IF_NOT_OK(eoe_op->Reset()); + } + } + return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc index 2e50c4d992..0a8591f5b1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc @@ -62,7 +62,15 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common detailed info PipelineOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_; + out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_ + << "\nLeaf Nodes in execution path:"; + if (!eoe_ops_.empty()) { + for (size_t i = 0; i < eoe_ops_.size(); i++) { + out << "\n Operator: " << eoe_ops_[i]->id(); + } + } else { + out << " None."; + } out << "\n\n"; } } @@ -107,9 +115,17 @@ Status RepeatOp::EoeReceived(int32_t worker_id) { if (repeat_count_ == num_repeats_) { repeat_count_ = 0; state_ = OpState::kDeOpIdle; + return Status::OK(); } else { state_ = OpState::kDeOpRunning; } + + // Invoke a reset against the eoe nodes only. + for (auto &eoe_op : eoe_ops_) { + MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id(); + RETURN_IF_NOT_OK(eoe_op->Reset()); + } + return Status::OK(); } @@ -138,6 +154,19 @@ int32_t RepeatOp::num_consumers() const { } } +// Drive reset actions if needed +Status RepeatOp::Reset() { + // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. + // In that case, we now have to bounce the reset down to our own eoe ops. + MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset."; + for (auto &eoe_op : eoe_ops_) { + MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id(); + RETURN_IF_NOT_OK(eoe_op->Reset()); + } + state_ = OpState::kDeOpRunning; + return Status::OK(); +} + int32_t RepeatOp::num_producers() const { if (child_.empty() || child_[0] == nullptr) { MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h index d2af23976e..35c3bfeea7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h @@ -129,6 +129,16 @@ class RepeatOp : public PipelineOp { /// \return The number of repeats that the user requested int32_t num_repeats() { return num_repeats_; } + /// \brief reset Op + /// \@return Status - The error code return + Status Reset() override; + + // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes + // \param[in] eoe_op The input leaf/eoe operator to add to the list + void AddToEoeList(std::shared_ptr eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } + + std::vector> eoe_ops_; // List of operators that can generate EOE underneath this repeat. + protected: // The number of repeats that the user requested. // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index 2e9ddddc2e..4bf2205744 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -186,6 +186,7 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) { Status GeneratorOp::operator()() { // Handshake with TaskManager to synchronize thread creation TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); std::unique_ptr fetched_buffer; bool eof = false; while (!eof) { @@ -227,8 +228,17 @@ Status GeneratorOp::operator()() { MS_LOG(DEBUG) << "Generator operator main execution loop complete."; eof = true; } else { - // Self-reset to start a new iteration - RETURN_IF_NOT_OK(Reset()); + // Waiting for repeatOp to start new epoch + // If Reset() is called first by repeat op, this wait() will return right away. + // If Reset() is not called yet, this wait() will block until reset. + if (this->op_total_repeats() < 0) { + RETURN_IF_NOT_OK(wp_.Wait()); + // Clear the status of the wait post + wp_.Clear(); + } else { + // Self-reset to start a new iteration + RETURN_IF_NOT_OK(Reset()); + } } UpdateRepeatAndEpochCounter(); } @@ -240,6 +250,10 @@ Status GeneratorOp::Reset() { // Reset Op state MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(this->Init()); + if (this->op_total_repeats() < 0) { + // Wake up master thread + wp_.Set(); + } return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h index 175d1ce680..1d7f2b97f3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h @@ -144,6 +144,8 @@ class GeneratorOp : public PipelineOp { py::object generator_; int32_t buffer_id_; + WaitPost wp_; + Status Init(); void Dealloc() noexcept; diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc index 16963278f2..c94b34468a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc @@ -22,15 +22,31 @@ #include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h" #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" +#include "minddata/dataset/engine/datasetops/source/generator_op.h" namespace mindspore { namespace dataset { RepeatPass::RepeatPass() - : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {} + : is_repeated_(false), + nested_repeats_(0), + num_repeats_(1), + num_epochs_(1), + is_merge_(false), + is_cached_(false), + cache_lookup_(nullptr) {} // Identifies the subtree below this node as being in a repeated path of the tree. Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Create a new stack for eoe operators and push onto our stack of stacks. + std::unique_ptr new_stack = std::make_unique(); + eoe_op_stacks_.push(std::move(new_stack)); + // If we are already repeated, then this is a nested repeat. + if (is_repeated_) { + nested_repeats_++; + } + is_repeated_ = true; + // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { @@ -58,7 +74,9 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modifie // that RepeatOp does. However, epoch control is actually simpler because it can // only exist as the root node so it doesn't need all the nested code. // Create a new stack for eoe operators and push onto our stack of stacks. - + std::unique_ptr new_stack = std::make_unique(); + eoe_op_stacks_.push(std::move(new_stack)); + is_repeated_ = true; // Get the total number of epochs from the EpochCtrlOp parameter num_epochs_ = node->num_repeats(); // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. @@ -85,6 +103,22 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { // Hooks up any identified eoe nodes under this repeat. Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking + std::shared_ptr leaf_op = PopFromEOEOpStack(); + + while (leaf_op != nullptr) { + node->AddToEoeList(leaf_op); + leaf_op = PopFromEOEOpStack(); + } + + // At this point, we are done with the save area stack. It's a unique pointer to an empty stack + // at this time, so we can pop it to get rid of it. + op_stack *current_stack = eoe_op_stacks_.top().get(); + if (!current_stack->empty()) { + RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); + } + eoe_op_stacks_.pop(); + // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up // and set its total repeats. It is important that the op is removed from the save area, // because the merge op above us may also take action on it later for a different case when @@ -95,6 +129,18 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { cache_lookup_.reset(); } + // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. + // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. + if (nested_repeats_ > 0) { + AddToEOEOpStack(node); + nested_repeats_--; + } else { + // If we are not nested, or we were the top-most repeat, now we clear the flag + if (nested_repeats_ != 0) { + RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!"); + } + is_repeated_ = false; + } if (is_cached_) { AddToCachedOpStack(node); } @@ -110,6 +156,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // Hooks up any identified eoe nodes under this repeat. Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Pop the leaf ops from the save-area stack and add them to the eoe node tracking + std::shared_ptr leaf_op = PopFromEOEOpStack(); + while (leaf_op != nullptr) { + node->AddToEoeList(leaf_op); + leaf_op = PopFromEOEOpStack(); + } + is_repeated_ = false; node->set_total_repeats(num_repeats_); node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); // We finish the walk of this EpochCtrl's descendent nodes. @@ -138,6 +191,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // If we are in a repeat path, then set our repeated flag + if (is_repeated_) { + // if infinite repeat save ourself in a stack for the repeat operator above us + if (num_repeats_ < 0) { + AddToEOEOpStack(node); + } + } + // If we are under a cache op, then save ourselves to the cached op stack. + if (is_cached_) { + AddToCachedOpStack(node); + } + // Set total repeats and total epochs for the node + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + return Status::OK(); +} // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up // for use with a controlling repeat above it. Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { @@ -190,6 +260,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified return Status::OK(); } +// Adds an operator to the eoe operator stack save area +void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { + op_stack *current_stack = eoe_op_stacks_.top().get(); + current_stack->push(dataset_op); +} + +// Pops an operator from the eoe operator stack save area +std::shared_ptr RepeatPass::PopFromEOEOpStack() { + std::shared_ptr top_op = nullptr; + op_stack *current_stack = eoe_op_stacks_.top().get(); + if (current_stack != nullptr && !current_stack->empty()) { + top_op = current_stack->top(); + current_stack->pop(); + } + return top_op; +} + // Adds an operator to the cached operator stack save area void RepeatPass::AddToCachedOpStack(std::shared_ptr dataset_op) { cached_op_stacks_.push(dataset_op); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h index 4345ecc6f6..897f9ab1dd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h @@ -98,6 +98,12 @@ class RepeatPass : public NodePass { /// \return Status The error code return Status RunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Special case for GeneratorOp + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up /// for use with a controlling repeat above it. /// \param[in] node The node being visited @@ -106,6 +112,19 @@ class RepeatPass : public NodePass { Status RunOnNode(std::shared_ptr node, bool *modified) override; private: + /// \brief Adds an operator to the eoe operator stack save area + /// \param op - The dataset op to work add to eoe stack + /// \return Status - The error code return + void AddToEOEOpStack(std::shared_ptr dataset_op); + + /// \brief Pops an operator from the eoe operator stack save area + /// \return shared_ptr to the popped operator + std::shared_ptr PopFromEOEOpStack(); + + bool is_repeated_; // T/F if we are processing under a repeat + int32_t nested_repeats_; // A counter for nested repeats + std::stack> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) + /// \brief Adds an operator to the cached operator stack save area /// \param op - The dataset op to work add to cached stack /// \return Status - The error code return