diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc index 1b0890686f..07b51e804b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc @@ -91,13 +91,14 @@ Status CacheBase::FetchSamplesToWorkers() { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); // If repeat but the not last repeat, wait for reset. - if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (!IsLastIteration()) { MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; RETURN_IF_NOT_OK(epoch_sync_.Wait()); } else { // We can break out from the loop. break; } + UpdateRepeatAndEpochCounter(); } while (true); // Flow the eof before exit RETURN_IF_NOT_OK( diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index e2a7fc3697..72238b748e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -294,7 +294,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) { Status CacheMergeOp::EoeReceived(int32_t worker_id) { // If we are in a repeat path, send the eoe up. // Otherwise ignore it. - if (BitTest(op_ctrl_flags_, kDeOpRepeated)) { + if (op_total_repeats_ > 1) { return DatasetOp::EoeReceived(worker_id); } return Status::OK(); @@ -306,7 +306,7 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) { // getting an eoe. However, the logic demands that all epochs close with an eoe first before eof. // Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class // provides that for us. - if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) { + if (op_total_repeats_ == 1) { MS_LOG(DEBUG) << "Cache merge sending eoe"; RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id)); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc index 143c45b2dc..c742d82522 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -85,6 +85,10 @@ Status CacheOp::operator()() { TaskManager::FindMe()->Post(); // Wait for the workers to finish caching the rows. RETURN_IF_NOT_OK(WaitForCachingAllRows()); + // Current repeats and current epochs may have increased when caching all rows with DatasetOp::GetNextInput. + // But they shouldn't be increased because now cache op is starting to act as a leaf and its epoch hasn't started. + op_current_repeats_ = 0; + op_current_epochs_ = 0; RETURN_IF_NOT_OK(FetchSamplesToWorkers()); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc index a01a9cc87f..a3ba23a07e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc @@ -87,6 +87,7 @@ Status ConcatOp::operator()() { auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); } + UpdateRepeatAndEpochCounter(); } CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_, "Something went wrong, eof count does not match the number of children."); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index 51237f58cd..c236698c15 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -42,7 +42,10 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr sampler operator_id_(kInvalidOperatorId), tree_(nullptr), state_(OpState::kDeOpIdle), - op_ctrl_flags_(kDeOpNone), + op_total_repeats_(kInfiniteRepeat), + op_num_repeats_per_epoch_(kInfiniteRepeat), + op_current_repeats_(0), + op_current_epochs_(0), out_connector_(nullptr) { // The operator starts out with an invalid operator id. The only way to // get it out of invalid state is to assign the operator to an execution tree. @@ -234,8 +237,8 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { for (size_t i = 0; i < parent_.size(); i++) { out << "\n Parent[" << i << "] id: " << parent_[i]->id(); } - out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex - << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' '); + out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_ + << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_; if (sampler_) { sampler_->Print(out, show_all); } @@ -264,6 +267,7 @@ Status DatasetOp::GetNextInput(std::unique_ptr *p_buffer, int32_t wo RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id)); // Loop until non EOE is received while (buf->eoe()) { + UpdateRepeatAndEpochCounter(); RETURN_IF_NOT_OK(EoeReceived(worker_id)); if (state_ == OpState::kDeOpIdle) { *p_buffer = std::move(buf); @@ -407,5 +411,10 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) { uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length()); return cache_crc; } + +void DatasetOp::UpdateRepeatAndEpochCounter() { + op_current_repeats_++; + if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++; +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index 01eb2f93c3..3c83582c9f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -70,13 +70,7 @@ class DatasetOp : public std::enable_shared_from_this { public: static constexpr int32_t kInvalidOperatorId = -1; - - // Operator control flags - enum OpControlFlags { - kDeOpNone = 0, - kDeOpRepeated = 1, // Operator is a node in a repeat path - kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop - }; + static constexpr int32_t kInfiniteRepeat = -1; // Flags that control operator runtime behaviours enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated }; @@ -238,13 +232,23 @@ class DatasetOp : public std::enable_shared_from_this { /// \return T/F if this is an inlined operator bool inlined() const { return (oc_queue_size_ == 0); } - /// \brief Setter function - /// \return Sets the control flags - void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); } + /// \brief Setter function, set the number of total repeats for the operator + void set_total_repeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; } + + /// \brief Setter function, set the number of repeats per epoch for the operator + void set_num_repeats_per_epoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; } - /// \brief Setter function - /// \return Sets the control flags - void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); } + /// \brief Getter function + /// \return The number of required repeats for the operator + int32_t op_total_repeats() { return op_total_repeats_; } + + /// \brief Getter function + /// \return The number of required epochs for the operator + int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; } + + /// \brief Getter function + /// \return The number of repeats per epoch for the operator + int32_t op_num_repeats_per_epoch() { return op_num_repeats_per_epoch_; } /// \brief Register the internal worker connectors. No op unless it is a parallel op /// \return Status @@ -350,6 +354,10 @@ class DatasetOp : public std::enable_shared_from_this { /// \return boolean returns true if it's a leaf bool IsLeaf() { return (child_.empty()); } + /// Checks if an operator has reached its last iteration + /// \return boolean returns true if it's last iteration + bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; } + protected: /// \brief Removes a parent operator from this operator /// \notes External callers do not have access to this function @@ -368,6 +376,10 @@ class DatasetOp : public std::enable_shared_from_this { /// \return - Status virtual Status ComputeColMap(); + /// Increase op_current_repeats_ by 1 when one repeat finished. + /// If this repeat happen to be the last repeat in the current epoch, also increase op_current_epochs_ by 1. + void UpdateRepeatAndEpochCounter(); + std::vector> child_; // Child nodes std::vector parent_; // Parent nodes. No ownership std::shared_ptr sampler_; // Some leaf ops might have a sampler @@ -375,7 +387,10 @@ class DatasetOp : public std::enable_shared_from_this { int32_t operator_id_; // Generated id for the node ExecutionTree *tree_; // Back pointer to our tree. OpState state_; // The state of the operator, Running, Idle, Terminated - uint32_t op_ctrl_flags_; // Flags for the operator + int32_t op_total_repeats_; // Required number of repeats for the operator + int32_t op_num_repeats_per_epoch_; // Total number of repeats per epoch for the operator + int32_t op_current_repeats_; // Current number of repeats the operator has handled + int32_t op_current_epochs_; // Current number of epochs the operator has handled std::unique_ptr out_connector_; // Output Connector std::unordered_map column_name_id_map_; // Mapping between col index and col name std::mutex column_name_map_mutex_; // For protecting shared access to the column map diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc index 82885b89c0..83c2681c5e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc @@ -30,7 +30,7 @@ namespace dataset { // The builder "build" method creates the final object. Status EpochCtrlOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_repeats_); + *ptr = std::make_shared(build_num_repeats_); return Status::OK(); } @@ -48,12 +48,12 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); // Then show any custom derived-internal 1-liner info for this op - out << " [epochs: " << max_repeats_ << "]\n"; + out << " [epochs: " << num_repeats_ << "]\n"; } else { // Call the super class for displaying any common detailed info PipelineOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_ + out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_ << "\nLeaf Nodes in execution path:"; if (!eoe_ops_.empty()) { for (size_t i = 0; i < eoe_ops_.size(); i++) { @@ -88,24 +88,15 @@ Status EpochCtrlOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t } Status EpochCtrlOp::EoeReceived(int32_t worker_id) { + UpdateRepeatAndEpochCounter(); repeat_count_++; MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_ - << ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_; - - // If we've reached the requested epoch count, then flag the leaf nodes - // to tell them they've got one more epoch to perform. When they reach the end - // of the last epoch, they quit rather than loop again. - if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) { - for (auto &eoe_op : eoe_ops_) { - MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << eoe_op->id(); - eoe_op->set_control_flag(kDeOpLastRepeat); - } - } + << ". Max epochs: " << num_repeats_; // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. state_ = OpState::kDeOpIdle; - if (repeat_count_ != max_repeats_) { + if (repeat_count_ != num_repeats_) { for (auto &eoe_op : eoe_ops_) { MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id(); RETURN_IF_NOT_OK(eoe_op->Reset()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc index 39cdb45b20..b819d31bad 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc @@ -119,6 +119,7 @@ Status FilterOp::WorkerEntry(int32_t worker_id) { RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); if (in_buffer->eoe()) { filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); + UpdateRepeatAndEpochCounter(); continue; } else if (in_buffer->eof()) { filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc index 89c06f4917..9a6daa62ce 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc @@ -233,6 +233,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) { // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work // with Performance Mode design. if (in_buffer->eoe()) { + UpdateRepeatAndEpochCounter(); // Calling base class EoeReceived to forward eoe buffer. RETURN_IF_NOT_OK(EoeReceived(worker_id)); // Fetch next data buffer and map job list diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc index 9c8013497c..b157d20c2b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc @@ -76,6 +76,9 @@ Status ProjectOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t w if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) { RETURN_IF_NOT_OK(Project(p_buffer)); } + if ((*p_buffer)->eoe()) { + UpdateRepeatAndEpochCounter(); + } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc index 83cd8b5af8..dd53fec97d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc @@ -28,10 +28,10 @@ namespace mindspore { namespace dataset { // Builder constructor. Creates the builder object. -RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {} +RepeatOp::Builder::Builder(int32_t count) : build_num_repeats_(count) {} Status RepeatOp::Builder::SanityCheck() const { - if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) { + if (build_num_repeats_ < kInfiniteRepeat || build_num_repeats_ == 0) { std::string err_msg("Repeat count must be > 0 or -1."); RETURN_STATUS_UNEXPECTED(err_msg); } @@ -41,12 +41,12 @@ Status RepeatOp::Builder::SanityCheck() const { // The builder "build" method creates the final object. Status RepeatOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_repeats_); + *ptr = std::make_shared(build_num_repeats_); return Status::OK(); } // Constructor of the RepeatOp. -RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {} +RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), num_repeats_(count), repeat_count_(0) {} // Destructor RepeatOp::~RepeatOp() {} @@ -59,12 +59,12 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); // Then show any custom derived-internal 1-liner info for this op - out << " [repeats: " << max_repeats_ << "]\n"; + out << " [repeats: " << num_repeats_ << "]\n"; } else { // Call the super class for displaying any common detailed info PipelineOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_ + out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_ << "\nLeaf Nodes in execution path:"; if (!eoe_ops_.empty()) { for (size_t i = 0; i < eoe_ops_.size(); i++) { @@ -109,22 +109,13 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t wo // Base-class override for handling cases when an eoe is received. Status RepeatOp::EoeReceived(int32_t worker_id) { + UpdateRepeatAndEpochCounter(); + repeat_count_++; MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; - bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated); - bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat); - // If we've reached the requested repeat count, then flag the eoe nodes - // to tell them they've got one more epoch to perform. When they reach the end - // of the last epoch, they quit rather than loop again. This happens in two cases: - // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or, - // 2- We are not repeated - if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) { - for (auto &eoe_op : eoe_ops_) { - eoe_op->set_control_flag(kDeOpLastRepeat); - } - } - if (repeat_count_ == max_repeats_) { + + if (repeat_count_ == num_repeats_) { repeat_count_ = 0; state_ = OpState::kDeOpIdle; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h index e763e2bcca..bdd4953541 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h @@ -26,8 +26,6 @@ namespace mindspore { namespace dataset { class RepeatOp : public PipelineOp { public: - static constexpr int32_t kInfiniteRepeat = -1; - // The nested builder class inside of the RepeatOp is used to help manage all of the arguments // for constructing it. This repeat op is very simple though, so this builder is really just // provided for a consistent look and feel for creators of Dataset operators overall. @@ -47,7 +45,7 @@ class RepeatOp : public PipelineOp { Status Build(std::shared_ptr *); protected: - int32_t build_max_repeats_; + int32_t build_num_repeats_; Status SanityCheck() const; }; @@ -131,13 +129,24 @@ class RepeatOp : public PipelineOp { // @return Name of the current Op std::string Name() const override { return kRepeatOp; } + /// \brief Getter function + /// \return The number of repeats that the user requested + int32_t num_repeats() { return num_repeats_; } + // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes // \param[in] eoe_op The input leaf/eoe operator to add to the list void AddToEoeList(std::shared_ptr eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } protected: - int32_t max_repeats_; // The number of repeats that the user requested - int32_t repeat_count_; // A counter for the current number of executed repeats + // The number of repeats that the user requested. + // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. + // For example, for repeat1 op in pipeline tfreader -> repeat1(3) -> repeat2(2) -> epoch ctrl(4), + // num_repeats_ = 3, op_total_repeats_ = 24, op_num_repeats_per_epoch_ = 6. + int32_t num_repeats_; + // A counter for the current number of executed repeats. + // Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class + // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats. + int32_t repeat_count_; std::vector> eoe_ops_; // List of operators that can generate EOE underneath this repeat. }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc index 7b374c4075..9ceca1923f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc @@ -293,7 +293,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr *data_buffer) { RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -310,6 +310,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr *data_buffer) { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); } + UpdateRepeatAndEpochCounter(); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc index b06fcdb55d..e37211486a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -120,7 +120,7 @@ Status CifarOp::operator()() { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -137,6 +137,7 @@ Status CifarOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index 239d323043..8a178a810a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -271,13 +271,14 @@ Status ClueOp::operator()() { std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { finished_reading_dataset_ = true; NotifyToFillIOBlockQueue(); } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; } + UpdateRepeatAndEpochCounter(); } std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index dac2f8f57d..b4df3b9f79 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -167,7 +167,7 @@ Status CocoOp::operator()() { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); @@ -184,6 +184,7 @@ Status CocoOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 1bf887458f..ba5bbcde0b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -472,13 +472,14 @@ Status CsvOp::operator()() { std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { finished_reading_dataset_ = true; NotifyToFillIOBlockQueue(); } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; } + UpdateRepeatAndEpochCounter(); } std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index 4af3042861..a259c6e574 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -218,7 +218,7 @@ Status GeneratorOp::operator()() { MS_LOG(DEBUG) << "Generator operator sends out EOE."; std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { // If last repeat or not repeated, push out EOF and exit master loop MS_LOG(DEBUG) << "Generator operator sends out EOF."; std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); @@ -233,6 +233,7 @@ Status GeneratorOp::operator()() { // Clear the status of the wait post wp_.Clear(); } + UpdateRepeatAndEpochCounter(); } } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc index 9a3bbccdcf..2e7fd90bf4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -151,7 +151,7 @@ Status ImageFolderOp::operator()() { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(keys, IOBlock::kDeIoBlockNone))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); @@ -168,6 +168,7 @@ Status ImageFolderOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index 7982c63d10..a46087f415 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -112,7 +112,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -129,6 +129,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc index 2af4259944..7cc2757a1a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -380,7 +380,7 @@ Status MindRecordOp::operator()() { RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -398,6 +398,7 @@ Status MindRecordOp::operator()() { RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); shard_reader_wait_post_.Clear(); } + UpdateRepeatAndEpochCounter(); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc index b3c52be60e..5ef09b5f7e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -111,7 +111,7 @@ Status MnistOp::operator()() { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -128,6 +128,7 @@ Status MnistOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc index 8e09cc2b6c..9199a83298 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -221,7 +221,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { all_out_.Wait(); // If we are not in a repeat loop, or that was the last repeat already, then setup our exit // condition from the master loop. - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { *quitting = true; } @@ -231,6 +231,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { if (last_guy_in) { MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker " << eoe_worker_id_; + UpdateRepeatAndEpochCounter(); // Prepare for sync all_out_.Clear(); // Always flow eoe at the end diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 104d7919ce..e1f8dab612 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -421,13 +421,14 @@ Status TextFileOp::operator()() { std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { finished_reading_dataset_ = true; NotifyToFillIOBlockQueue(); } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; } + UpdateRepeatAndEpochCounter(); } std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index f0d40a9ba8..dbd44ec447 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -310,13 +310,14 @@ Status TFReaderOp::operator()() { std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { finished_reading_dataset_ = true; NotifyToFillIOBlockQueue(); } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; } + UpdateRepeatAndEpochCounter(); } std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index bb48d5e418..d3b6ea00f9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -145,7 +145,7 @@ Status VOCOp::operator()() { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); @@ -162,6 +162,7 @@ Status VOCOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc index dfd4f254e0..3a8344550f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc @@ -84,6 +84,7 @@ Status TakeOp::operator()() { // Loop until non EOE is received if (buf->eoe()) { + UpdateRepeatAndEpochCounter(); take_count_ = 0; RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc index cf8852bf44..aac0eaa2e9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc @@ -25,18 +25,44 @@ namespace mindspore { namespace dataset { -RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {} +RepeatPass::RepeatPass() + : is_repeated_(false), + nested_repeats_(0), + num_repeats_(1), + num_epochs_(1), + is_merge_(false), + is_cached_(false), + cache_lookup_(nullptr) {} // Identifies the subtree below this node as being in a repeated path of the tree. Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { // Create a new stack for eoe operators and push onto our stack of stacks. - std::unique_ptr new_stack = std::make_unique(); + std::unique_ptr new_stack = std::make_unique(); eoe_op_stacks_.push(std::move(new_stack)); // If we are already repeated, then this is a nested repeat. if (is_repeated_) { nested_repeats_++; } is_repeated_ = true; + + // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. + // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. + if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { + num_repeats_ = -num_repeats_; + } + // This RepeatOp and its descendent nodes should be repeated for another num_repeats() times. + // + // Consider this example: + // tfreader --> map --> repeat(2) --> epoch ctrl(3) + // num_repeats_ is originally 3, after this repeat(2), num_repeats_ becomes 6 (2*3), + // meaning repeat op should be set to read 6 times (2*3), do does map op and tfreader op. + // + // Another example: + // tfreader --> repeat1(3) --> map --> repeat2(2) --> epoch ctrl(4) + // num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4), + // meaning repeat2 and map op should be set to read 8 times (2*4). + // Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times. + num_repeats_ *= node->num_repeats(); return Status::OK(); } @@ -46,9 +72,16 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modifie // that RepeatOp does. However, epoch control is actually simpler because it can // only exist as the root node so it doesn't need all the nested code. // Create a new stack for eoe operators and push onto our stack of stacks. - std::unique_ptr new_stack = std::make_unique(); + std::unique_ptr new_stack = std::make_unique(); eoe_op_stacks_.push(std::move(new_stack)); is_repeated_ = true; + // Get the total number of epochs from the EpochCtrlOp parameter + num_epochs_ = node->num_repeats(); + // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. + // For example: tfreader --> epoch ctrl(3) + // num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3), + // meaning epoch ctrl op should be set to read 3 times (1*3), so does tfreader op. + num_repeats_ *= num_epochs_; return Status::OK(); } @@ -59,6 +92,13 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modifi return Status::OK(); } +// Identifies the subtree below this node as being cached +Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that we're under a merge op + is_cached_ = true; + return Status::OK(); +} + // Hooks up any identified eoe nodes under this repeat. Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking @@ -71,7 +111,7 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // At this point, we are done with the save area stack. It's a unique pointer to an empty stack // at this time, so we can pop it to get rid of it. - eoe_op_stack *current_stack = eoe_op_stacks_.top().get(); + op_stack *current_stack = eoe_op_stacks_.top().get(); if (!current_stack->empty()) { RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); } @@ -82,14 +122,14 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // from the save area, because the merge op above us may also take action on it later for a different // case when there is no repeat in the merge leg. if (is_merge_ && cache_lookup_) { - cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated); + cache_lookup_->set_total_repeats(num_repeats_); + cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); node->AddToEoeList(std::move(cache_lookup_)); } // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. if (nested_repeats_ > 0) { - node->set_control_flag(DatasetOp::kDeOpRepeated); AddToEOEOpStack(node); nested_repeats_--; } else { @@ -99,7 +139,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { } is_repeated_ = false; } - + if (is_cached_) { + AddToCachedOpStack(node); + } + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + // We finish the walk of this RepeatOp's descendent nodes. + // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n. + // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode, + // so we devide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp. + num_repeats_ /= node->num_repeats(); return Status::OK(); } @@ -112,13 +161,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) leaf_op = PopFromEOEOpStack(); } is_repeated_ = false; + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + // We finish the walk of this EpochCtrl's descendent nodes. + num_repeats_ /= node->num_repeats(); return Status::OK(); } // CacheOp removes previous leaf ops and replaces them with itself Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + is_cached_ = false; if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); // if we are a cache within a repeat path of the tree, then there will be // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the // repeat or epoch ctrl operators can work with them for repeat activity during runtime. @@ -130,13 +183,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // the repeating behaviours shall be invoked against the cache op. std::shared_ptr leaf_op = PopFromEOEOpStack(); while (leaf_op != nullptr) { - leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat); - leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated); leaf_op = PopFromEOEOpStack(); } AddToEOEOpStack(std::static_pointer_cast(node)); + + // adjust the total epochs and total repeats for ops under this cache op + std::shared_ptr cached_op = PopFromCachedOpStack(); + while (cached_op != nullptr) { + int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; + cached_op->set_total_repeats(cached_op_total_repeats); + // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 + cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); + cached_op = PopFromCachedOpStack(); + } } + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); return Status::OK(); } @@ -145,13 +208,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // If we are in a repeat path, then set our repeated flag if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); - // if we are a leaf node then save ourself in a stack for the repeat operator above us if (node->IsLeaf()) { AddToEOEOpStack(node); } } + if (is_cached_) { + AddToCachedOpStack(node); + } + // Set total repeats and total epochs for the node + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); return Status::OK(); } @@ -159,13 +226,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // Setting the flag is needed since we didn't call the base class DatasetOp version if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); // If there was not any repeat in the merge cache miss leg, then the cache_lookup // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack if (cache_lookup_) { + cache_lookup_->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); AddToEOEOpStack(std::move(cache_lookup_)); } } + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used is_merge_ = false; return Status::OK(); @@ -178,13 +248,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); } - // If we are in a repeat path already, then there must be a repeat above the merge op - // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. - if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); - // Delay the assigment of this leap to the eoe stack and allow the merge op processing to handle that. - } - // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. @@ -197,19 +260,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified // Adds an operator to the eoe operator stack save area void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { - eoe_op_stack *current_stack = eoe_op_stacks_.top().get(); + op_stack *current_stack = eoe_op_stacks_.top().get(); current_stack->push(dataset_op); } // Pops an operator from the eoe operator stack save area std::shared_ptr RepeatPass::PopFromEOEOpStack() { std::shared_ptr top_op = nullptr; - eoe_op_stack *current_stack = eoe_op_stacks_.top().get(); + op_stack *current_stack = eoe_op_stacks_.top().get(); if (current_stack != nullptr && !current_stack->empty()) { top_op = current_stack->top(); current_stack->pop(); } return top_op; } + +// Adds an operator to the cached operator stack save area +void RepeatPass::AddToCachedOpStack(std::shared_ptr dataset_op) { cached_op_stacks_.push(dataset_op); } + +// Pops an operator from the cached operator stack save area +std::shared_ptr RepeatPass::PopFromCachedOpStack() { + std::shared_ptr top_op = nullptr; + if (!cached_op_stacks_.empty()) { + top_op = cached_op_stacks_.top(); + cached_op_stacks_.pop(); + } + return top_op; +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h index 67a243f44c..1e865eadac 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h @@ -30,7 +30,7 @@ namespace dataset { /// to the eoe-producing (typically leaf) nodes underneath it. class RepeatPass : public NodePass { public: - using eoe_op_stack = std::stack>; + using op_stack = std::stack>; /// \brief Constructor RepeatPass(); @@ -56,6 +56,12 @@ class RepeatPass : public NodePass { /// \return Status The error code return Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Identifies the subtree below this node as being cached + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Hooks up any identified eoe nodes under this repeat. /// \param[in] node The node being visited /// \param[inout] modified Indicator if the node was changed at all @@ -103,11 +109,24 @@ class RepeatPass : public NodePass { /// \return shared_ptr to the popped operator std::shared_ptr PopFromEOEOpStack(); - bool is_repeated_; // T/F if we are processing under a repeat - bool is_merge_; // T/F if we are processing under a cache merge op - int32_t nested_repeats_; // A counter for nested repeats - std::stack> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) - std::shared_ptr cache_lookup_; // A save area for a cache lookup op + /// \brief Adds an operator to the cached operator stack save area + /// \param op - The dataset op to work add to cached stack + /// \return Status - The error code return + void AddToCachedOpStack(std::shared_ptr dataset_op); + + /// \brief Pops an operator from the cached operator stack save area + /// \return shared_ptr to the popped operator + std::shared_ptr PopFromCachedOpStack(); + + bool is_repeated_; // T/F if we are processing under a repeat + bool is_merge_; // T/F if we are processing under a cache merge op + bool is_cached_; // T/F is we are processing under a cache op + int32_t nested_repeats_; // A counter for nested repeats + int32_t num_repeats_; // A multiplier to the total number of repeats + int32_t num_epochs_; // To save the total number of epochs + std::stack> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) + op_stack cached_op_stacks_; // A save area for ops under a cache op + std::shared_ptr cache_lookup_; // A save area for a cache lookup op }; } // namespace dataset } // namespace mindspore diff --git a/tests/ut/python/dataset/test_epoch_ctrl.py b/tests/ut/python/dataset/test_epoch_ctrl.py index 20f7b1ef65..3a5ddb3b8c 100644 --- a/tests/ut/python/dataset/test_epoch_ctrl.py +++ b/tests/ut/python/dataset/test_epoch_ctrl.py @@ -565,6 +565,99 @@ def test_generator_tuple_repeat_repeat_3(): # rely on garbage collector to destroy iter1 + +def test_generator_tuple_infinite_repeat_repeat_1(): + """ + test generator tuple infinite repeat repeat 1 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat() + data1 = data1.repeat(3) + iter1 = data1.create_tuple_iterator(num_epochs=11) + + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + if i == 100: + break + + # rely on garbage collector to destroy iter1 + + +def test_generator_tuple_infinite_repeat_repeat_2(): + """ + test generator tuple infinite repeat repeat 2 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat(3) + data1 = data1.repeat() + iter1 = data1.create_tuple_iterator(num_epochs=11) + + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + if i == 100: + break + + # rely on garbage collector to destroy iter1 + + +def test_generator_tuple_infinite_repeat_repeat_3(): + """ + test generator tuple infinite repeat repeat 3 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat() + data1 = data1.repeat() + iter1 = data1.create_tuple_iterator(num_epochs=11) + + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + if i == 100: + break + + # rely on garbage collector to destroy iter1 + + +def test_generator_tuple_infinite_repeat_repeat_4(): + """ + test generator tuple infinite repeat repeat 4 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat() + data1 = data1.repeat() + iter1 = data1.create_tuple_iterator() + + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + if i == 100: + break + + # rely on garbage collector to destroy iter1 + + def test_generator_reusedataset(): """ test generator reusedataset