!3974 Eliminate the pause for each cycle of the Repeat

Merge pull request !3974 from lixiachen/repeat_task2
pull/3974/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0d486adef1

@ -218,16 +218,6 @@ Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t ba
return Status::OK(); return Status::OK();
} }
// Resets the per-epoch batching state of this op.
// Zeroes batch_count_ and replaces every entry of buckets_ with a freshly
// default-constructed TensorQTable, discarding whatever rows were still held.
// @return Status - Always Status::OK()
Status BucketBatchByLengthOp::Reset() {
  batch_count_ = 0;
  // Iterate by reference instead of using a signed index compared against the
  // unsigned buckets_.size(), which avoids the signed/unsigned mismatch.
  for (auto &bucket : buckets_) {
    bucket = std::make_unique<TensorQTable>();
  }
  return Status::OK();
}
// Computing the assignment of the column name map and check compute input columns. // Computing the assignment of the column name map and check compute input columns.
Status BucketBatchByLengthOp::ComputeColMap() { Status BucketBatchByLengthOp::ComputeColMap() {
RETURN_IF_NOT_OK(DatasetOp::ComputeColMap()); RETURN_IF_NOT_OK(DatasetOp::ComputeColMap());

@ -126,10 +126,6 @@ class BucketBatchByLengthOp : public PipelineOp {
// @return Status - The error code returned // @return Status - The error code returned
Status operator()() override; Status operator()() override;
// Function that is called by ResetOp at the end of every epoch
// @return Status - The error code returned
Status Reset() override;
private: private:
Status ObtainElementLength(int32_t *out_element_length, TensorRow element); Status ObtainElementLength(int32_t *out_element_length, TensorRow element);

@ -42,8 +42,7 @@ Status CacheBase::Reset() {
RETURN_IF_NOT_OK(sampler_->ResetSampler()); RETURN_IF_NOT_OK(sampler_->ResetSampler());
} }
// Wake up the workers to get them going again in a new epoch // Wake up the workers to get them going again in a new epoch
MS_LOG(DEBUG) << Name() << " resetting."; MS_LOG(DEBUG) << Name() << " performing a self-reset.";
epoch_sync_.Set();
return Status::OK(); return Status::OK();
} }
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
@ -72,7 +71,6 @@ Status CacheBase::FetchSamplesToWorkers() {
// Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them
// to the WorkerEntry. // to the WorkerEntry.
do { do {
epoch_sync_.Clear();
if (AllowCacheMiss() && wait_cnt > 0) { if (AllowCacheMiss() && wait_cnt > 0) {
MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_ MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_
<< " Total number of rows : " << row_cnt_; << " Total number of rows : " << row_cnt_;
@ -112,11 +110,17 @@ Status CacheBase::FetchSamplesToWorkers() {
// If repeat but the not last repeat, wait for reset. // If repeat but the not last repeat, wait for reset.
if (!IsLastIteration()) { if (!IsLastIteration()) {
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
RETURN_IF_NOT_OK(epoch_sync_.Wait());
} else { } else {
// We can break out from the loop. // We can break out from the loop.
break; break;
} }
if (epoch_sync_flag_) {
// If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for
// the current epoch.
RETURN_IF_NOT_OK(WaitForWorkers());
}
// If not the last repeat, self-reset and go to loop again.
if (!IsLastIteration()) RETURN_IF_NOT_OK(Reset());
UpdateRepeatAndEpochCounter(); UpdateRepeatAndEpochCounter();
} while (true); } while (true);
// Flow the eof before exit // Flow the eof before exit
@ -142,7 +146,13 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
std::unique_ptr<IOBlock> blk; std::unique_ptr<IOBlock> blk;
do { do {
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
if (blk->eof()) { if (blk->wait()) {
// Sync io_block is a signal that master thread wants us to pause and sync with other workers.
// The last guy who comes to this sync point should reset the counter and wake up the master thread.
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
}
} else if (blk->eof()) {
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))); RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
} else if (blk->eoe()) { } else if (blk->eoe()) {
if (AllowCacheMiss()) { if (AllowCacheMiss()) {
@ -186,7 +196,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
} }
Status CacheBase::RegisterResources() { Status CacheBase::RegisterResources() {
RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks())); RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks()));

@ -26,7 +26,6 @@
#include "minddata/dataset/engine/cache/cache_service.h" #include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/util/queue.h" #include "minddata/dataset/util/queue.h"
@ -88,7 +87,6 @@ class CacheBase : public ParallelOp {
int64_t row_cnt_; int64_t row_cnt_;
std::atomic<int64_t> num_cache_miss_; std::atomic<int64_t> num_cache_miss_;
std::shared_ptr<CacheClient> cache_client_; std::shared_ptr<CacheClient> cache_client_;
WaitPost epoch_sync_;
int32_t rows_per_buffer_; int32_t rows_per_buffer_;
Connector<std::vector<row_id_type>> keys_miss_; Connector<std::vector<row_id_type>> keys_miss_;
QueueMap<row_id_type, TensorRow> prefetch_; QueueMap<row_id_type, TensorRow> prefetch_;
@ -110,7 +108,6 @@ class CacheBase : public ParallelOp {
private: private:
constexpr static int32_t connector_capacity_ = 1024; constexpr static int32_t connector_capacity_ = 1024;
int32_t prefetch_size_; int32_t prefetch_size_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
QueueList<std::unique_ptr<IOBlock>> prefetch_queues_; QueueList<std::unique_ptr<IOBlock>> prefetch_queues_;
std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_; std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_;

@ -434,6 +434,7 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
void DatasetOp::UpdateRepeatAndEpochCounter() { void DatasetOp::UpdateRepeatAndEpochCounter() {
op_current_repeats_++; op_current_repeats_++;
if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++; if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++;
MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_;
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -51,15 +51,7 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
PipelineOp::Print(out, show_all); PipelineOp::Print(out, show_all);
// Then show any custom derived-internal stuff // Then show any custom derived-internal stuff
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_ out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_;
<< "\nLeaf Nodes in execution path:";
if (!eoe_ops_.empty()) {
for (size_t i = 0; i < eoe_ops_.size(); i++) {
out << "\n Operator: " << eoe_ops_[i]->id();
}
} else {
out << " None.";
}
out << "\n\n"; out << "\n\n";
} }
} }
@ -94,13 +86,6 @@ Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
// This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
state_ = OpState::kDeOpIdle; state_ = OpState::kDeOpIdle;
if (repeat_count_ != num_repeats_) {
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());
}
}
return Status::OK(); return Status::OK();
} }

@ -123,7 +123,6 @@ Status FilterOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
if (in_buffer->eoe()) { if (in_buffer->eoe()) {
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe));
UpdateRepeatAndEpochCounter();
continue; continue;
} else if (in_buffer->eof()) { } else if (in_buffer->eof()) {
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof));
@ -200,6 +199,7 @@ Status FilterOp::Collector() {
RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair));
if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial ||
in_pair.second == filterCtrl::kFilterEoe) { in_pair.second == filterCtrl::kFilterEoe) {
if (in_pair.second == filterCtrl::kFilterEoe) UpdateRepeatAndEpochCounter();
uint32_t out_task_id = out_id_cnt % num_workers_; uint32_t out_task_id = out_id_cnt % num_workers_;
RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first))); RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
out_id_cnt++; out_id_cnt++;

@ -228,12 +228,6 @@ class MapOp : public ParallelOp {
// Indices of the columns to process. // Indices of the columns to process.
std::vector<size_t> to_process_indices_; std::vector<size_t> to_process_indices_;
// Wait post used to perform the pausing logic in MapOp
WaitPost wait_for_workers_post_;
// Count number of workers that have signaled master
std::atomic_int num_workers_paused_;
// Private function for worker/thread to loop continuously. It comprises the main // Private function for worker/thread to loop continuously. It comprises the main
// logic of MapOp: getting the data from previous Op, validating user specified column names, // logic of MapOp: getting the data from previous Op, validating user specified column names,
// applying a list of TensorOps to each of the data, process the results and then // applying a list of TensorOps to each of the data, process the results and then

@ -31,7 +31,9 @@ ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shar
num_workers_(num_workers), num_workers_(num_workers),
num_producers_(num_workers), num_producers_(num_workers),
worker_connector_size_(1), worker_connector_size_(1),
worker_connector_(nullptr) {} worker_connector_(nullptr),
num_workers_paused_(0),
epoch_sync_flag_(false) {}
// Creates the internal worker connector for the parallel op if the derived class wants to use it // Creates the internal worker connector for the parallel op if the derived class wants to use it
Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) {
@ -82,5 +84,15 @@ Status ParallelOp::RegisterWorkerConnectors() {
} }
return Status::OK(); return Status::OK();
} }
// Blocks the calling thread until wait_for_workers_post_ is signalled.
// Steps, in order:
//   1. reset the paused-worker counter,
//   2. push one kDeIoBlockFlagWait IOBlock onto each of the num_workers_ io_block queues,
//   3. wait on wait_for_workers_post_,
//   4. clear the post so that it can be waited on again.
// The worker loops visible in this change increment num_workers_paused_ when they pop a wait block and
// signal the post once the count reaches num_workers_.
// @return Status - Status::OK() on success; failures from Add()/Wait() are handed to RETURN_IF_NOT_OK
Status ParallelOp::WaitForWorkers() {
// The counter is zeroed before any wait block is queued, so increments made for this round are not wiped.
num_workers_paused_ = 0;
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(io_block_queues_[i]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagWait)));
}
RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
// Clear after waking up so that the post is not left set for the next call.
wait_for_workers_post_.Clear();
return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h" #include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
@ -117,10 +118,27 @@ class ParallelOp : public DatasetOp {
// @return Status - The error code return // @return Status - The error code return
virtual Status WorkerEntry(int32_t workerId) = 0; virtual Status WorkerEntry(int32_t workerId) = 0;
/// This function is only intended to be called by CallbackManager within the master thread of ParallelOp
/// The expected behavior is this, when this function is invoked, this function will block until all the workers
/// have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master.
/// They would automatically wait on the QueueList when they are done.
/// \return Status
Status WaitForWorkers() override;
// Wait post used to perform the pausing logic
WaitPost wait_for_workers_post_;
// Count number of workers that have signaled master
std::atomic_int num_workers_paused_;
// Whether or not to sync worker threads at the end of each epoch
bool epoch_sync_flag_;
int32_t num_workers_; // The number of worker threads int32_t num_workers_; // The number of worker threads
int32_t num_producers_; // The number of threads pushing to the out_connector_ int32_t num_producers_; // The number of threads pushing to the out_connector_
int32_t worker_connector_size_; int32_t worker_connector_size_;
std::unique_ptr<DbConnector> worker_connector_; // The internal connector for worker threads std::unique_ptr<DbConnector> worker_connector_; // The internal connector for worker threads
QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -62,15 +62,7 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
PipelineOp::Print(out, show_all); PipelineOp::Print(out, show_all);
// Then show any custom derived-internal stuff // Then show any custom derived-internal stuff
out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_ out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_;
<< "\nLeaf Nodes in execution path:";
if (!eoe_ops_.empty()) {
for (size_t i = 0; i < eoe_ops_.size(); i++) {
out << "\n Operator: " << eoe_ops_[i]->id();
}
} else {
out << " None.";
}
out << "\n\n"; out << "\n\n";
} }
} }
@ -108,7 +100,6 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
// Base-class override for handling cases when an eoe is received. // Base-class override for handling cases when an eoe is received.
Status RepeatOp::EoeReceived(int32_t worker_id) { Status RepeatOp::EoeReceived(int32_t worker_id) {
UpdateRepeatAndEpochCounter(); UpdateRepeatAndEpochCounter();
repeat_count_++; repeat_count_++;
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ MS_LOG(DEBUG) << "Repeat operator (" << operator_id_
<< ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; << ") end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
@ -116,15 +107,9 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
if (repeat_count_ == num_repeats_) { if (repeat_count_ == num_repeats_) {
repeat_count_ = 0; repeat_count_ = 0;
state_ = OpState::kDeOpIdle; state_ = OpState::kDeOpIdle;
return Status::OK(); } else {
} state_ = OpState::kDeOpRunning;
// Invoke a reset against the eoe nodes only.
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());
} }
return Status::OK(); return Status::OK();
} }
@ -153,19 +138,6 @@ int32_t RepeatOp::num_consumers() const {
} }
} }
// Handles a reset request on this repeat op by forwarding it to every tracked eoe op.
// @return Status - The error code return
Status RepeatOp::Reset() {
  MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset.";
  // With nested repeats, an ancestor repeat may list this op among its own eoe ops, so the
  // reset it sends us has to be passed on to each op recorded in eoe_ops_.
  for (auto it = eoe_ops_.begin(); it != eoe_ops_.end(); ++it) {
    auto &tracked_op = *it;
    MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << tracked_op->id();
    RETURN_IF_NOT_OK(tracked_op->Reset());
  }
  // Every tracked op accepted the reset; mark this op as running again.
  state_ = OpState::kDeOpRunning;
  return Status::OK();
}
int32_t RepeatOp::num_producers() const { int32_t RepeatOp::num_producers() const {
if (child_.empty() || child_[0] == nullptr) { if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";

@ -101,10 +101,6 @@ class RepeatOp : public PipelineOp {
// @param worker_id - The worker id // @param worker_id - The worker id
Status EofReceived(int32_t worker_id) override; Status EofReceived(int32_t worker_id) override;
/// \brief reset Op
/// \@return Status - The error code return
Status Reset() override;
// Base-class override. Return the number of workers in the first parent. // Base-class override. Return the number of workers in the first parent.
// @param workerId - The worker id // @param workerId - The worker id
int32_t num_consumers() const override; int32_t num_consumers() const override;
@ -133,10 +129,6 @@ class RepeatOp : public PipelineOp {
/// \return The number of repeats that the user requested /// \return The number of repeats that the user requested
int32_t num_repeats() { return num_repeats_; } int32_t num_repeats() { return num_repeats_; }
/// \brief Appends an operator to eoe_ops_, the list of leaf/eoe nodes tracked by this repeat op
/// \param[in] eoe_op The leaf/eoe operator to record; ownership of the shared pointer is moved into the list
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) {
  eoe_ops_.emplace_back(std::move(eoe_op));
}
protected: protected:
// The number of repeats that the user requested. // The number of repeats that the user requested.
// Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class.
@ -147,7 +139,6 @@ class RepeatOp : public PipelineOp {
// Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class // Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class
// because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats. // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats.
int32_t repeat_count_; int32_t repeat_count_;
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -161,11 +161,19 @@ Status AlbumOp::operator()() {
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
} }
return Status::OK(); return Status::OK();
} else { // not the last repeat. Sleep master thread, wait for the wake-up from reset } else { // not the last repeat.
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks }
wp_.Clear();
if (epoch_sync_flag_) {
// If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for
// the current epoch.
RETURN_IF_NOT_OK(WaitForWorkers());
}
// If not the last repeat, self-reset and go to loop again.
if (!IsLastIteration()) {
RETURN_IF_NOT_OK(Reset());
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
} }
UpdateRepeatAndEpochCounter(); UpdateRepeatAndEpochCounter();
@ -180,7 +188,13 @@ Status AlbumOp::WorkerEntry(int32_t worker_id) {
std::unique_ptr<IOBlock> io_block; std::unique_ptr<IOBlock> io_block;
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
while (io_block != nullptr) { while (io_block != nullptr) {
if (io_block->eoe() == true) { if (io_block->wait() == true) {
// Sync io_block is a signal that master thread wants us to pause and sync with other workers.
// The last guy who comes to this sync point should reset the counter and wake up the master thread.
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
}
} else if (io_block->eoe() == true) {
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
buffer_id = worker_id; buffer_id = worker_id;
} else if (io_block->eof() == true) { } else if (io_block->eof() == true) {
@ -468,9 +482,9 @@ void AlbumOp::Print(std::ostream &out, bool show_all) const {
// Reset Sampler and wakeup Master thread (functor) // Reset Sampler and wakeup Master thread (functor)
Status AlbumOp::Reset() { Status AlbumOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
RETURN_IF_NOT_OK(sampler_->ResetSampler()); RETURN_IF_NOT_OK(sampler_->ResetSampler());
row_cnt_ = 0; row_cnt_ = 0;
wp_.Set(); // wake up master thread after reset is done
return Status::OK(); return Status::OK();
} }
@ -486,7 +500,7 @@ Status AlbumOp::LaunchThreadsAndInitOp() {
} }
// registers QueueList and individual Queues for interrupt services // registers QueueList and individual Queues for interrupt services
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
// launch main workers that load DataBuffers by reading all images // launch main workers that load DataBuffers by reading all images
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&AlbumOp::WorkerEntry, this, std::placeholders::_1))); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&AlbumOp::WorkerEntry, this, std::placeholders::_1)));
TaskManager::FindMe()->Post(); TaskManager::FindMe()->Post();

@ -30,7 +30,6 @@
#include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h" #include "minddata/dataset/util/queue.h"
@ -289,9 +288,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
int64_t buf_cnt_; int64_t buf_cnt_;
int64_t sampler_ind_; int64_t sampler_ind_;
int64_t dirname_offset_; int64_t dirname_offset_;
WaitPost wp_;
std::vector<std::string> image_rows_; std::vector<std::string> image_rows_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -94,7 +94,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() {
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks())); RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this))); RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this)));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1)));
@ -311,11 +311,19 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
} }
return Status::OK(); return Status::OK();
} else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset } else { // not the last repeat.
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks }
wp_.Clear();
if (epoch_sync_flag_) {
// If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for
// the current epoch.
RETURN_IF_NOT_OK(WaitForWorkers());
}
// If not the last repeat, self-reset and go to loop again.
if (!IsLastIteration()) {
RETURN_IF_NOT_OK(Reset());
RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer));
} }
UpdateRepeatAndEpochCounter(); UpdateRepeatAndEpochCounter();
@ -328,7 +336,13 @@ Status CelebAOp::WorkerEntry(int32_t worker_id) {
std::unique_ptr<IOBlock> io_block; std::unique_ptr<IOBlock> io_block;
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
while (io_block != nullptr) { while (io_block != nullptr) {
if (io_block->eoe() == true) { if (io_block->wait() == true) {
// Sync io_block is a signal that master thread wants us to pause and sync with other workers.
// The last guy who comes to this sync point should reset the counter and wake up the master thread.
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
}
} else if (io_block->eoe() == true) {
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
buffer_id = worker_id; buffer_id = worker_id;
} else if (io_block->eof() == true) { } else if (io_block->eof() == true) {
@ -410,8 +424,8 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const {
// Reset Sampler and wakeup Master thread (functor) // Reset Sampler and wakeup Master thread (functor)
Status CelebAOp::Reset() { Status CelebAOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
RETURN_IF_NOT_OK(sampler_->ResetSampler()); RETURN_IF_NOT_OK(sampler_->ResetSampler());
wp_.Set(); // wake up master thread after reset is done
return Status::OK(); return Status::OK();
} }

@ -229,8 +229,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
std::unique_ptr<DataSchema> data_schema_; std::unique_ptr<DataSchema> data_schema_;
std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_; std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
int64_t num_rows_in_attr_file_; // rows number specified in attr file int64_t num_rows_in_attr_file_; // rows number specified in attr file
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
WaitPost wp_;
std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_; std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
std::string usage_; std::string usage_;
std::ifstream partition_file_; std::ifstream partition_file_;

@ -140,11 +140,19 @@ Status CifarOp::operator()() {
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
} }
return Status::OK(); return Status::OK();
} else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset } else { // not the last repeat.
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks }
wp_.Clear();
if (epoch_sync_flag_) {
// If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for
// the current epoch.
RETURN_IF_NOT_OK(WaitForWorkers());
}
// If not the last repeat, self-reset and go to loop again.
if (!IsLastIteration()) {
RETURN_IF_NOT_OK(Reset());
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
} }
UpdateRepeatAndEpochCounter(); UpdateRepeatAndEpochCounter();
@ -156,7 +164,7 @@ Status CifarOp::LaunchThreadsAndInitOp() {
RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
} }
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
tree_->AllTasks()->CreateAsyncTask("Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this))); tree_->AllTasks()->CreateAsyncTask("Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this)));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1))); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1)));
@ -175,7 +183,13 @@ Status CifarOp::WorkerEntry(int32_t worker_id) {
std::unique_ptr<IOBlock> io_block; std::unique_ptr<IOBlock> io_block;
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
while (io_block != nullptr) { while (io_block != nullptr) {
if (io_block->eoe() == true) { if (io_block->wait() == true) {
// Sync io_block is a signal that master thread wants us to pause and sync with other workers.
// The last guy who comes to this sync point should reset the counter and wake up the master thread.
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
}
} else if (io_block->eoe() == true) {
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
buffer_id = worker_id; buffer_id = worker_id;
} else if (io_block->eof() == true) { } else if (io_block->eof() == true) {
@ -243,9 +257,9 @@ void CifarOp::Print(std::ostream &out, bool show_all) const {
// Reset Sampler and wakeup Master thread (functor) // Reset Sampler and wakeup Master thread (functor)
Status CifarOp::Reset() { Status CifarOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
RETURN_IF_NOT_OK(sampler_->ResetSampler()); RETURN_IF_NOT_OK(sampler_->ResetSampler());
row_cnt_ = 0; row_cnt_ = 0;
wp_.Set(); // wake up master thread after reset is done
return Status::OK(); return Status::OK();
} }

@ -26,7 +26,6 @@
#include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h" #include "minddata/dataset/util/queue.h"
@ -233,11 +232,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
int32_t rows_per_buffer_; int32_t rows_per_buffer_;
std::string folder_path_; std::string folder_path_;
std::unique_ptr<DataSchema> data_schema_; std::unique_ptr<DataSchema> data_schema_;
int64_t row_cnt_; int64_t row_cnt_;
int64_t buf_cnt_; int64_t buf_cnt_;
const std::string usage_; // can only be either "train" or "test" const std::string usage_; // can only be either "train" or "test"
WaitPost wp_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_; std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_;
std::vector<std::string> cifar_files_; std::vector<std::string> cifar_files_;
std::vector<std::pair<std::shared_ptr<Tensor>, std::vector<uint32_t>>> cifar_image_label_pairs_; std::vector<std::pair<std::shared_ptr<Tensor>, std::vector<uint32_t>>> cifar_image_label_pairs_;

@ -119,6 +119,7 @@ Status ClueOp::Init() {
} }
Status ClueOp::Reset() { Status ClueOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
load_jagged_connector_ = true; load_jagged_connector_ = true;
load_io_block_queue_ = true; load_io_block_queue_ = true;
@ -274,6 +275,8 @@ Status ClueOp::operator()() {
} else { } else {
jagged_buffer_connector_->DoReset(); jagged_buffer_connector_->DoReset();
buffer_id = 0; buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
} }
UpdateRepeatAndEpochCounter(); UpdateRepeatAndEpochCounter();
} }

@ -25,7 +25,6 @@
#include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {

@ -185,8 +185,16 @@ Status CocoOp::operator()() {
} else { } else {
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait()); }
wp_.Clear();
if (epoch_sync_flag_) {
// If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for
// the current epoch.
RETURN_IF_NOT_OK(WaitForWorkers());
}
// If not the last repeat, self-reset and go to loop again.
if (!IsLastIteration()) {
RETURN_IF_NOT_OK(Reset());
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
} }
UpdateRepeatAndEpochCounter(); UpdateRepeatAndEpochCounter();
@ -208,9 +216,9 @@ void CocoOp::Print(std::ostream &out, bool show_all) const {
} }
Status CocoOp::Reset() { Status CocoOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
RETURN_IF_NOT_OK(sampler_->ResetSampler()); RETURN_IF_NOT_OK(sampler_->ResetSampler());
row_cnt_ = 0; row_cnt_ = 0;
wp_.Set();
return Status::OK(); return Status::OK();
} }
@ -377,7 +385,13 @@ Status CocoOp::WorkerEntry(int32_t worker_id) {
std::unique_ptr<IOBlock> io_block; std::unique_ptr<IOBlock> io_block;
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
while (io_block != nullptr) { while (io_block != nullptr) {
if (io_block->eoe() == true) { if (io_block->wait() == true) {
// Sync io_block is a signal that master thread wants us to pause and sync with other workers.
// The last guy who comes to this sync point should reset the counter and wake up the master thread.
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
}
} else if (io_block->eoe() == true) {
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
buffer_id = worker_id; buffer_id = worker_id;
} else if (io_block->eof() == true) { } else if (io_block->eof() == true) {
@ -609,7 +623,7 @@ Status CocoOp::LaunchThreadsAndInitOp() {
RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
} }
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1))); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1)));
TaskManager::FindMe()->Post(); TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(this->ParseAnnotationIds()); RETURN_IF_NOT_OK(this->ParseAnnotationIds());

@ -27,7 +27,6 @@
#include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/path.h"
@ -327,10 +326,8 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
std::shared_ptr<Sampler> sampler_; std::shared_ptr<Sampler> sampler_;
std::unique_ptr<DataSchema> data_schema_; std::unique_ptr<DataSchema> data_schema_;
WaitPost wp_;
std::vector<std::string> image_ids_; std::vector<std::string> image_ids_;
std::map<int32_t, std::string> image_index_; std::map<int32_t, std::string> image_index_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
std::vector<std::pair<std::string, std::vector<int32_t>>> label_index_; std::vector<std::pair<std::string, std::vector<int32_t>>> label_index_;
std::map<std::string, CoordinateRow> coordinate_map_; std::map<std::string, CoordinateRow> coordinate_map_;
std::map<std::string, std::vector<uint32_t>> simple_item_map_; std::map<std::string, std::vector<uint32_t>> simple_item_map_;

@ -479,6 +479,7 @@ Status CsvOp::CsvParser::InitCsvParser() {
} }
Status CsvOp::Reset() { Status CsvOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
load_jagged_connector_ = true; load_jagged_connector_ = true;
load_io_block_queue_ = true; load_io_block_queue_ = true;
@ -572,6 +573,8 @@ Status CsvOp::operator()() {
} else { } else {
jagged_buffer_connector_->DoReset(); jagged_buffer_connector_->DoReset();
buffer_id = 0; buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
} }
UpdateRepeatAndEpochCounter(); UpdateRepeatAndEpochCounter();
} }

@ -186,7 +186,6 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) {
Status GeneratorOp::operator()() { Status GeneratorOp::operator()() {
// Handshake with TaskManager to synchronize thread creation // Handshake with TaskManager to synchronize thread creation
TaskManager::FindMe()->Post(); TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
std::unique_ptr<DataBuffer> fetched_buffer; std::unique_ptr<DataBuffer> fetched_buffer;
bool eof = false; bool eof = false;
while (!eof) { while (!eof) {
@ -228,12 +227,8 @@ Status GeneratorOp::operator()() {
MS_LOG(DEBUG) << "Generator operator main execution loop complete."; MS_LOG(DEBUG) << "Generator operator main execution loop complete.";
eof = true; eof = true;
} else { } else {
// Waiting for repeatOp to start new epoch // Self-reset to start a new iteration
// If Reset() is called first by repeat op, this wait() will return right away. RETURN_IF_NOT_OK(Reset());
// If Reset() is not called yet, this wait() will block until reset.
RETURN_IF_NOT_OK(wp_.Wait());
// Clear the status of the wait post
wp_.Clear();
} }
UpdateRepeatAndEpochCounter(); UpdateRepeatAndEpochCounter();
} }
@ -243,9 +238,8 @@ Status GeneratorOp::operator()() {
Status GeneratorOp::Reset() { Status GeneratorOp::Reset() {
// Reset Op state // Reset Op state
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
RETURN_IF_NOT_OK(this->Init()); RETURN_IF_NOT_OK(this->Init());
// Wake up master thread
wp_.Set();
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
} }

@ -144,8 +144,6 @@ class GeneratorOp : public PipelineOp {
py::object generator_; py::object generator_;
int32_t buffer_id_; int32_t buffer_id_;
WaitPost wp_;
Status Init(); Status Init();
void Dealloc() noexcept; void Dealloc() noexcept;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save