Maintain epoch/repeat count for ops

pull/3346/head
Lixia Chen 5 years ago
parent 11b3c91156
commit ac85b77b76

@ -91,13 +91,14 @@ Status CacheBase::FetchSamplesToWorkers() {
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
// If repeat but the not last repeat, wait for reset.
if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (!IsLastIteration()) {
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
RETURN_IF_NOT_OK(epoch_sync_.Wait());
} else {
// We can break out from the loop.
break;
}
UpdateRepeatAndEpochCounter();
} while (true);
// Flow the eof before exit
RETURN_IF_NOT_OK(

@ -294,7 +294,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
// If we are in a repeat path, send the eoe up.
// Otherwise ignore it.
if (BitTest(op_ctrl_flags_, kDeOpRepeated)) {
if (op_total_repeats_ > 1) {
return DatasetOp::EoeReceived(worker_id);
}
return Status::OK();
@ -306,7 +306,7 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) {
// getting an eoe. However, the logic demands that all epochs close with an eoe first before eof.
// Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class
// provides that for us.
if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) {
if (op_total_repeats_ == 1) {
MS_LOG(DEBUG) << "Cache merge sending eoe";
RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id));
}

@ -85,6 +85,10 @@ Status CacheOp::operator()() {
TaskManager::FindMe()->Post();
// Wait for the workers to finish caching the rows.
RETURN_IF_NOT_OK(WaitForCachingAllRows());
// Current repeats and current epochs may have increased when caching all rows with DatasetOp::GetNextInput.
// But they shouldn't be increased because now cache op is starting to act as a leaf and its epoch hasn't started.
op_current_repeats_ = 0;
op_current_epochs_ = 0;
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
return Status::OK();
}

@ -87,6 +87,7 @@ Status ConcatOp::operator()() {
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
}
UpdateRepeatAndEpochCounter();
}
CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
"Something went wrong, eof count does not match the number of children.");

@ -42,7 +42,10 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler
operator_id_(kInvalidOperatorId),
tree_(nullptr),
state_(OpState::kDeOpIdle),
op_ctrl_flags_(kDeOpNone),
op_total_repeats_(kInfiniteRepeat),
op_num_repeats_per_epoch_(kInfiniteRepeat),
op_current_repeats_(0),
op_current_epochs_(0),
out_connector_(nullptr) {
// The operator starts out with an invalid operator id. The only way to
// get it out of invalid state is to assign the operator to an execution tree.
@ -234,8 +237,8 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const {
for (size_t i = 0; i < parent_.size(); i++) {
out << "\n Parent[" << i << "] id: " << parent_[i]->id();
}
out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex
<< std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' ');
out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_
<< "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_;
if (sampler_) {
sampler_->Print(out, show_all);
}
@ -264,6 +267,7 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
// Loop until non EOE is received
while (buf->eoe()) {
UpdateRepeatAndEpochCounter();
RETURN_IF_NOT_OK(EoeReceived(worker_id));
if (state_ == OpState::kDeOpIdle) {
*p_buffer = std::move(buf);
@ -407,5 +411,10 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length());
return cache_crc;
}
void DatasetOp::UpdateRepeatAndEpochCounter() {
op_current_repeats_++;
if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++;
}
} // namespace dataset
} // namespace mindspore

@ -70,13 +70,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
public:
static constexpr int32_t kInvalidOperatorId = -1;
// Operator control flags
enum OpControlFlags {
kDeOpNone = 0,
kDeOpRepeated = 1, // Operator is a node in a repeat path
kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop
};
static constexpr int32_t kInfiniteRepeat = -1;
// Flags that control operator runtime behaviours
enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated };
@ -238,13 +232,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return T/F if this is an inlined operator
bool inlined() const { return (oc_queue_size_ == 0); }
/// \brief Setter function
/// \return Sets the control flags
void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); }
/// \brief Setter function, set the number of total repeats for the operator
void set_total_repeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
/// \brief Setter function, set the number of repeats per epoch for the operator
void set_num_repeats_per_epoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
/// \brief Setter function
/// \return Sets the control flags
void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); }
/// \brief Getter function
/// \return The number of required repeats for the operator
int32_t op_total_repeats() { return op_total_repeats_; }
/// \brief Getter function
/// \return The number of required epochs for the operator
int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; }
/// \brief Getter function
/// \return The number of repeats per epoch for the operator
int32_t op_num_repeats_per_epoch() { return op_num_repeats_per_epoch_; }
/// \brief Register the internal worker connectors. No op unless it is a parallel op
/// \return Status
@ -350,6 +354,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return boolean returns true if it's a leaf
bool IsLeaf() { return (child_.empty()); }
/// Checks if an operator has reached its last iteration
/// \return boolean returns true if it's last iteration
bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; }
protected:
/// \brief Removes a parent operator from this operator
/// \notes External callers do not have access to this function
@ -368,6 +376,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return - Status
virtual Status ComputeColMap();
/// Increase op_current_repeats_ by 1 when one repeat finished.
/// If this repeat happen to be the last repeat in the current epoch, also increase op_current_epochs_ by 1.
void UpdateRepeatAndEpochCounter();
std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler
@ -375,7 +387,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
int32_t operator_id_; // Generated id for the node
ExecutionTree *tree_; // Back pointer to our tree.
OpState state_; // The state of the operator, Running, Idle, Terminated
uint32_t op_ctrl_flags_; // Flags for the operator
int32_t op_total_repeats_; // Required number of repeats for the operator
int32_t op_num_repeats_per_epoch_; // Total number of repeats per epoch for the operator
int32_t op_current_repeats_; // Current number of repeats the operator has handled
int32_t op_current_epochs_; // Current number of epochs the operator has handled
std::unique_ptr<DbConnector> out_connector_; // Output Connector
std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name
std::mutex column_name_map_mutex_; // For protecting shared access to the column map

@ -30,7 +30,7 @@ namespace dataset {
// The builder "build" method creates the final object.
Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<EpochCtrlOp>(build_max_repeats_);
*ptr = std::make_shared<EpochCtrlOp>(build_num_repeats_);
return Status::OK();
}
@ -48,12 +48,12 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common 1-liner info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << " [epochs: " << max_repeats_ << "]\n";
out << " [epochs: " << num_repeats_ << "]\n";
} else {
// Call the super class for displaying any common detailed info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_
<< "\nLeaf Nodes in execution path:";
if (!eoe_ops_.empty()) {
for (size_t i = 0; i < eoe_ops_.size(); i++) {
@ -88,24 +88,15 @@ Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t
}
Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
UpdateRepeatAndEpochCounter();
repeat_count_++;
MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_
<< ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_;
// If we've reached the requested epoch count, then flag the leaf nodes
// to tell them they've got one more epoch to perform. When they reach the end
// of the last epoch, they quit rather than loop again.
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << eoe_op->id();
eoe_op->set_control_flag(kDeOpLastRepeat);
}
}
<< ". Max epochs: " << num_repeats_;
// This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
state_ = OpState::kDeOpIdle;
if (repeat_count_ != max_repeats_) {
if (repeat_count_ != num_repeats_) {
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());

@ -119,6 +119,7 @@ Status FilterOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
if (in_buffer->eoe()) {
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe));
UpdateRepeatAndEpochCounter();
continue;
} else if (in_buffer->eof()) {
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof));

@ -233,6 +233,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
// Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work
// with Performance Mode design.
if (in_buffer->eoe()) {
UpdateRepeatAndEpochCounter();
// Calling base class EoeReceived to forward eoe buffer.
RETURN_IF_NOT_OK(EoeReceived(worker_id));
// Fetch next data buffer and map job list

@ -76,6 +76,9 @@ Status ProjectOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t w
if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) {
RETURN_IF_NOT_OK(Project(p_buffer));
}
if ((*p_buffer)->eoe()) {
UpdateRepeatAndEpochCounter();
}
return Status::OK();
}

@ -28,10 +28,10 @@
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {}
RepeatOp::Builder::Builder(int32_t count) : build_num_repeats_(count) {}
Status RepeatOp::Builder::SanityCheck() const {
if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) {
if (build_num_repeats_ < kInfiniteRepeat || build_num_repeats_ == 0) {
std::string err_msg("Repeat count must be > 0 or -1.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
@ -41,12 +41,12 @@ Status RepeatOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status RepeatOp::Builder::Build(std::shared_ptr<RepeatOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<RepeatOp>(build_max_repeats_);
*ptr = std::make_shared<RepeatOp>(build_num_repeats_);
return Status::OK();
}
// Constructor of the RepeatOp.
RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {}
RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), num_repeats_(count), repeat_count_(0) {}
// Destructor
RepeatOp::~RepeatOp() {}
@ -59,12 +59,12 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common 1-liner info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << " [repeats: " << max_repeats_ << "]\n";
out << " [repeats: " << num_repeats_ << "]\n";
} else {
// Call the super class for displaying any common detailed info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_
<< "\nLeaf Nodes in execution path:";
if (!eoe_ops_.empty()) {
for (size_t i = 0; i < eoe_ops_.size(); i++) {
@ -109,22 +109,13 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
// Base-class override for handling cases when an eoe is received.
Status RepeatOp::EoeReceived(int32_t worker_id) {
UpdateRepeatAndEpochCounter();
repeat_count_++;
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_
<< ") end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
// If we've reached the requested repeat count, then flag the eoe nodes
// to tell them they've got one more epoch to perform. When they reach the end
// of the last epoch, they quit rather than loop again. This happens in two cases:
// 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
// 2- We are not repeated
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
for (auto &eoe_op : eoe_ops_) {
eoe_op->set_control_flag(kDeOpLastRepeat);
}
}
if (repeat_count_ == max_repeats_) {
if (repeat_count_ == num_repeats_) {
repeat_count_ = 0;
state_ = OpState::kDeOpIdle;
return Status::OK();

@ -26,8 +26,6 @@ namespace mindspore {
namespace dataset {
class RepeatOp : public PipelineOp {
public:
static constexpr int32_t kInfiniteRepeat = -1;
// The nested builder class inside of the RepeatOp is used to help manage all of the arguments
// for constructing it. This repeat op is very simple though, so this builder is really just
// provided for a consistent look and feel for creators of Dataset operators overall.
@ -47,7 +45,7 @@ class RepeatOp : public PipelineOp {
Status Build(std::shared_ptr<RepeatOp> *);
protected:
int32_t build_max_repeats_;
int32_t build_num_repeats_;
Status SanityCheck() const;
};
@ -131,13 +129,24 @@ class RepeatOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kRepeatOp; }
/// \brief Getter function
/// \return The number of repeats that the user requested
int32_t num_repeats() { return num_repeats_; }
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
// \param[in] eoe_op The input leaf/eoe operator to add to the list
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
protected:
int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats
// The number of repeats that the user requested.
// Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class.
// For example, for repeat1 op in pipeline tfreader -> repeat1(3) -> repeat2(2) -> epoch ctrl(4),
// num_repeats_ = 3, op_total_repeats_ = 24, op_num_repeats_per_epoch_ = 6.
int32_t num_repeats_;
// A counter for the current number of executed repeats.
// Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class
// because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats.
int32_t repeat_count_;
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.
};
} // namespace dataset

@ -293,7 +293,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add(
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
}
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
RETURN_IF_NOT_OK(
io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(
@ -310,6 +310,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer));
}
UpdateRepeatAndEpochCounter();
}
}

@ -120,7 +120,7 @@ Status CifarOp::operator()() {
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
}
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(
@ -137,6 +137,7 @@ Status CifarOp::operator()() {
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
UpdateRepeatAndEpochCounter();
}
}

@ -271,13 +271,14 @@ Status ClueOp::operator()() {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
}
UpdateRepeatAndEpochCounter();
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

@ -167,7 +167,7 @@ Status CocoOp::operator()() {
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
}
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe);
std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block)));
@ -184,6 +184,7 @@ Status CocoOp::operator()() {
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
UpdateRepeatAndEpochCounter();
}
}

@ -472,13 +472,14 @@ Status CsvOp::operator()() {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
}
UpdateRepeatAndEpochCounter();
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

@ -218,7 +218,7 @@ Status GeneratorOp::operator()() {
MS_LOG(DEBUG) << "Generator operator sends out EOE.";
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
// If last repeat or not repeated, push out EOF and exit master loop
MS_LOG(DEBUG) << "Generator operator sends out EOF.";
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
@ -233,6 +233,7 @@ Status GeneratorOp::operator()() {
// Clear the status of the wait post
wp_.Clear();
}
UpdateRepeatAndEpochCounter();
}
}
return Status::OK();

@ -151,7 +151,7 @@ Status ImageFolderOp::operator()() {
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(keys, IOBlock::kDeIoBlockNone)));
}
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe);
std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block)));
@ -168,6 +168,7 @@ Status ImageFolderOp::operator()() {
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
UpdateRepeatAndEpochCounter();
}
}

@ -112,7 +112,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
}
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(
@ -129,6 +129,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer));
}
UpdateRepeatAndEpochCounter();
}
}

@ -380,7 +380,7 @@ Status MindRecordOp::operator()() {
RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add(
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
}
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
RETURN_IF_NOT_OK(
io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(
@ -398,6 +398,7 @@ Status MindRecordOp::operator()() {
RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait());
shard_reader_wait_post_.Clear();
}
UpdateRepeatAndEpochCounter();
}
}

@ -111,7 +111,7 @@ Status MnistOp::operator()() {
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
}
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(
@ -128,6 +128,7 @@ Status MnistOp::operator()() {
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
UpdateRepeatAndEpochCounter();
}
}

@ -221,7 +221,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) {
all_out_.Wait();
// If we are not in a repeat loop, or that was the last repeat already, then setup our exit
// condition from the master loop.
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
*quitting = true;
}
@ -231,6 +231,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) {
if (last_guy_in) {
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker "
<< eoe_worker_id_;
UpdateRepeatAndEpochCounter();
// Prepare for sync
all_out_.Clear();
// Always flow eoe at the end

@ -421,13 +421,14 @@ Status TextFileOp::operator()() {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
}
UpdateRepeatAndEpochCounter();
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);

@ -310,13 +310,14 @@ Status TFReaderOp::operator()() {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
if (IsLastIteration()) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
}
UpdateRepeatAndEpochCounter();
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save