inject epoch ctrl op in the execution tree and send eos at the end of epoch

pull/3212/head
anzhengqi 5 years ago
parent f30df6e3e8
commit 008b91b2a1

@ -25,6 +25,8 @@
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/filter_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
@ -84,7 +86,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp}};
{kClue, &DEPipeline::ParseClueOp},
{kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
@ -166,8 +169,8 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &
Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }
// Function to launch the tree execution.
Status DEPipeline::LaunchTreeExec() {
RETURN_IF_NOT_OK(tree_->Prepare());
Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) {
RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
RETURN_IF_NOT_OK(tree_->Launch());
iterator_ = std::make_unique<DatasetIterator>(tree_);
if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
@ -252,6 +255,16 @@ int DEPipeline::GetRepeatCount() const { return repeat_num_; }
float ToFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }
Status DEPipeline::StopSend() {
// tree_.root() must be DeviceQueueOp
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
if (op == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "StopSend only supported by DeviceQueueOp");
}
op->StopSend();
return Status::OK();
}
int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
@ -804,6 +817,18 @@ Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}
Status DEPipeline::ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
if (args["count"].is_none()) {
std::string err_msg = "Error: count is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::shared_ptr<EpochCtrlOp> op;
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(ToInt(args["count"])).Build(&op));
*top = op;
return Status::OK();
}
Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>();
@ -973,8 +998,8 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<Data
(void)builder->SetDeviceType(ToString(value));
} else if (key == "device_id") {
(void)builder->SetDeviceId(ToInt(value));
} else if (key == "num_batch") {
(void)builder->SetNumBatch(ToInt(value));
} else if (key == "send_epoch_end") {
(void)builder->SetSendEpochEnd(ToBool(value));
}
}
}

@ -70,7 +70,8 @@ enum OpName {
kRandomData,
kTextFile,
kBuildVocab,
kClue
kClue,
kEpochCtrl
};
// The C++ binder class that we expose to the python script.
@ -90,7 +91,7 @@ class DEPipeline {
Status AssignRootNode(const DsOpPtr &dataset_op);
// Function to launch the tree execution.
Status LaunchTreeExec();
Status LaunchTreeExec(int32_t num_epochs);
// Get a row of data as dictionary of column name to the value.
Status GetNextAsMap(py::dict *output);
@ -143,6 +144,10 @@ class DEPipeline {
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);
Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
@ -189,6 +194,8 @@ class DEPipeline {
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status StopSend();
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
private:

@ -159,7 +159,7 @@ void bindDEPipeline(py::module *m) {
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
.def("SetBatchParameters",
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
.def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
.def("GetNextAsMap",
[](DEPipeline &de) {
py::dict out;
@ -188,6 +188,7 @@ void bindDEPipeline(py::module *m) {
.def("GetBatchSize", &DEPipeline::GetBatchSize)
.def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true;
@ -999,7 +1000,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("BUILDVOCAB", OpName::kBuildVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile)
.value("CLUE", OpName::kClue);
.value("CLUE", OpName::kClue)
.value("EPOCHCTRL", OpName::kEpochCtrl);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_JIEBA_MIX", JiebaMode::kMix)

@ -40,7 +40,9 @@ Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
out_map->clear();
TensorRow curr_row;
MS_LOG(INFO) << "get next as map start.";
RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
MS_LOG(INFO) << "fetchNextTensor success.";
// Return empty map if there's no data
if (curr_row.empty()) {
@ -105,7 +107,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you
// want to iterate again.
if (eof_handled_) {
return Status::OK();
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
}
// Check if we need to get a new DataBuffer to iterate.
@ -119,36 +122,22 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
// Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
// handle eoe and eof messages here.
//
// An eoe buffer means we have iterated fully to the end of the tree.
// An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of
// all operators.
// An eoe buffer means we have finished iterating one epoch.
// The next buffer in the pipeline might be an EOF or a databuffer for the next epoch.
if (curr_buffer_->eoe()) {
MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row.";
// Before returning the last empty vector, fetch the eof buffer which should be the last
// buffer, and then free it.
RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
if (!curr_buffer_->eof()) {
RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
}
eof_handled_ = true;
curr_buffer_.reset(); // explicitly free the eof buffer
// Set tree to Finished state
root_->Tree()->SetFinished();
MS_LOG(INFO) << "End of data iteration.";
curr_buffer_.reset(); // explicitly free the eoe buffer
return Status::OK();
}
// An eof buffer means it is the end of execution and all operators are shutting down.
// Because there is no more data to return to the caller, this sets the `eof_handled_` state and
// returns an unexpected-error status.
if (curr_buffer_->eof()) {
// An eof by itself, without being preceded by an eoe, is possible if a repeat operator
// exists below us in the stack. Repeat operator eats eoe's but eventually allows the
// flow of an eof up the pipeline by itself.
eof_handled_ = true;
curr_buffer_.reset(); // explicitly free the eof buffer
// Set tree to Finished state
root_->Tree()->SetFinished();
return Status::OK();
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
}
}
@ -208,20 +197,24 @@ Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you
// want to iterate again.
if (eof_handled_) {
return Status::OK();
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
}
// Check if we need to get a new DataBuffer to iterate.
if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
// GetNextInput() depends on current_op's EoeReceived. So, the EOE buffer might already have been
// handled and this child iterator might not see it.
RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
// Unlike the DatasetIterator, this child iterator does not quit after eoe.
// Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
// If an eoe is picked up here, we simply return an empty vector and it's up to the
// caller to decide what it wants to do next.
if (curr_buffer_->eoe()) {
MS_LOG(DEBUG) << "Child iterator picked up EOE.";
end_epoch_ = true;
return Status::OK();
} else {
end_epoch_ = false;
}
if (curr_buffer_->eof()) {

@ -144,6 +144,9 @@ class ChildIterator : public IteratorBase {
// @return The string to column id mapping.
std::unordered_map<std::string, int32_t> GetColumnNameMap() const override;
// Return T/F if end of epoch
bool end_of_epoch() { return end_epoch_; }
private:
DatasetOp *current_op_; // The parent operator. We consume from its children.
int32_t child_idx_; // The specific child this iterator will fetch from.
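The new end_of_epoch() getter lets a caller distinguish an epoch boundary from true end-of-data when ChildIterator::FetchNextTensorRow() returns an empty row. A minimal sketch of such a drain loop, assuming a parent op holds a ChildIterator in a variable named child_iter (hypothetical name, for illustration only):

  TensorRow row;
  RETURN_IF_NOT_OK(child_iter->FetchNextTensorRow(&row));
  while (!row.empty()) {
    // consume the row
    RETURN_IF_NOT_OK(child_iter->FetchNextTensorRow(&row));
  }
  if (child_iter->end_of_epoch()) {
    // the empty row marked an EOE; more epochs may follow
  } else {
    // the empty row marked the EOF path: no more data
  }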

@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
shuffle_op.cc
zip_op.cc
concat_op.cc
epoch_ctrl_op.cc
cache_base_op.cc
cache_lookup_op.cc
cache_op.cc

@ -17,11 +17,13 @@
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include <algorithm>
#include <iomanip>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
@ -202,5 +204,29 @@ BuildVocabOp::Builder::Builder()
builder_num_workers_ = cfg->num_parallel_workers();
builder_connector_size_ = cfg->op_connector_size();
}
// A print method typically used for debugging
void BuildVocabOp::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <BuildVocabOp>:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCode is needed here to show more info about the op."
<< "\n\n";
}
}
// Pre-Visitor accept method for NodePass
Status BuildVocabOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<BuildVocabOp>(), modified);
}
} // namespace dataset
} // namespace mindspore

@ -131,6 +131,21 @@ class BuildVocabOp : public ParallelOp {
~BuildVocabOp() = default;
/// \brief A print method typically used for debugging
/// \param[out] out The output stream to write output to
/// \param[in] show_all A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
/// \brief Stream output operator overload
/// \notes This allows you to write the debug print info using stream operators
/// \param[out] out Reference to the output stream being overloaded
/// \param[in] vop - reference to the BuildVocabOp to display
/// \return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const BuildVocabOp &vop) {
vop.Print(out, false);
return out;
}
Status WorkerEntry(int32_t worker_id) override;
// collect the work product from each worker
@ -152,6 +167,12 @@ class BuildVocabOp : public ParallelOp {
Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
private:
const int32_t interval_;
bool special_first_;

@ -96,7 +96,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
}
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
RETURN_IF_NOT_OK(EofReceived(worker_id));
return Status::OK();
}
Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
@ -298,5 +298,19 @@ Status CacheMergeOp::EoeReceived(int32_t worker_id) {
}
return Status::OK();
}
// Base-class override for handling cases when an eof is received.
Status CacheMergeOp::EofReceived(int32_t worker_id) {
// If we are not in a repeated path, then the merge op gets an eof by itself, without first
// getting an eoe. However, the logic demands that all epochs close with an eoe before the eof.
// Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class
// provides that for us.
if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) {
MS_LOG(DEBUG) << "Cache merge sending eoe";
RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id));
}
MS_LOG(DEBUG) << "Cache merge sending eof";
return DatasetOp::EofReceived(worker_id);
}
} // namespace dataset
} // namespace mindspore

@ -176,6 +176,11 @@ class CacheMergeOp : public ParallelOp {
/// \return Status object
Status EoeReceived(int32_t worker_id) override;
/// \brief Base-class override for handling cases when an eof is received.
/// \param worker_id - The worker id
/// \return Status - The error code return
Status EofReceived(int32_t worker_id) override;
protected:
Status ComputeColMap() override;

@ -26,6 +26,7 @@
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/opt/pass.h"
@ -102,6 +103,15 @@ Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
}
return Status::OK();
}
// Removes all child operators from this operator.
Status DatasetOp::RemoveChildren() {
for (const auto &child : child_) {
child->RemoveParent(this);
}
child_.clear();
return Status::OK();
}
// Adds a parent operator to this operator
void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
@ -185,6 +195,12 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
}
}
// Getter function to get all of our children.
std::vector<std::shared_ptr<DatasetOp>> DatasetOp::children() const { return child_; }
// Getter function to get all of our parents.
std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; }
// Creates the connector within this operator
void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers

@ -76,6 +76,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status The error code returned
Status Remove();
// Removes all child operators from this operator.
Status RemoveChildren();
/// \brief Getter function to get a shared pointer to our child
/// \param[in] child_index An operator can have n children. Indicates which child to return.
/// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
@ -86,6 +89,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
void Parent(DatasetOp **parent, int32_t parent_index) const;
// Getter function to get all of our children.
std::vector<std::shared_ptr<DatasetOp>> children() const;
// Getter function to get all of our parents.
std::vector<DatasetOp *> parents() const;
// Inserts a operator as the parent current op.
// Inserted op will become the sole parent of the current op.
// The existing parent of the current op will be transferred to the inserted op.

@ -25,19 +25,21 @@
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/device_queue_tracing.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
int32_t op_connector_size, int64_t num_batch)
int32_t op_connector_size, bool send_epoch_end)
: PipelineOp(op_connector_size),
channel_name_(channel_name),
device_type_(device_type),
device_id_(device_id),
prefetch_size_(prefetch_size),
num_batch_(num_batch) {}
send_epoch_end_(send_epoch_end),
stop_send_(false) {}
DeviceQueueOp::~DeviceQueueOp() {}
@ -53,8 +55,7 @@ DeviceQueueOp::Builder::Builder(int32_t prefetch_size)
: builder_prefetch_size_(prefetch_size),
builder_device_id_(0),
builder_device_type_(DeviceType::CPU),
builder_channel_name_(""),
builder_num_batch_(0) {
builder_channel_name_("") {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_op_connector_size_ = cfg->op_connector_size();
}
@ -64,6 +65,18 @@ Status DeviceQueueOp::EoeReceived(int32_t worker_id) {
return Status::OK();
}
Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
// this method checks if the buffer meets the conditions to be sent to TDT
if (buffer->NumRows() != 0) {
TensorRow row;
buffer->GetRow(0, &row);
for (const auto &item : row) {
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
}
}
return Status::OK();
}
Status DeviceQueueOp::operator()() {
TaskManager::FindMe()->Post();
@ -82,23 +95,10 @@ Status DeviceQueueOp::operator()() {
return Status::OK();
}
Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
// this method checks if the buffer meets the conditions to be sent to TDT
if (buffer->NumRows() != 0) {
TensorRow row;
buffer->GetRow(0, &row);
for (const auto &item : row) {
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
}
}
return Status::OK();
}
#ifdef ENABLE_TDTQUE
Status DeviceQueueOp::SendDataToAscend() {
MS_LOG(INFO) << "Device queue, sending data to Ascend.";
int64_t total_batch = 0;
bool is_break_loop = false;
double batch_start_time, end_time;
int32_t batch_cost, tdt_cost;
int32_t connector_size = 0;
@ -115,15 +115,20 @@ Status DeviceQueueOp::SendDataToAscend() {
std::unique_ptr<DataBuffer> current_buffer;
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
while (!current_buffer->eof() && !is_break_loop) {
while (!current_buffer->eoe() && !is_break_loop) {
while (!current_buffer->eof()) {
while (!current_buffer->eoe()) {
RETURN_IF_NOT_OK(CheckExceptions(current_buffer));
TensorRow currRow;
for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) {
for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) {
RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
if (status == TdtStatus::FAILED) {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
}
if (isProfilingEnable) {
@ -140,9 +145,6 @@ Status DeviceQueueOp::SendDataToAscend() {
profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size);
}
total_batch++;
if (num_batch_ > 0 && total_batch == num_batch_) {
is_break_loop = true;
}
}
if (isProfilingEnable) {
connector_size = ChildOpConnectorSize();
@ -150,6 +152,19 @@ Status DeviceQueueOp::SendDataToAscend() {
}
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
}
if (current_buffer->eoe() && send_epoch_end_) {
TensorRow currRow;
auto status =
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
if (status == TdtStatus::FAILED) {
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
}
}
if (isProfilingEnable) {
connector_size = ChildOpConnectorSize();
connector_capacity = ChildOpConnectorCapacity();
@ -158,7 +173,7 @@ Status DeviceQueueOp::SendDataToAscend() {
}
tree_->SetFinished();
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
MS_LOG(INFO) << "Device queue total batch is " << total_batch;
return Status::OK();
}
@ -196,9 +211,6 @@ Status DeviceQueueOp::SendDataToGPU() {
}
RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle));
total_batch++;
if (num_batch_ > 0 && total_batch == num_batch_) {
is_break_loop = true;
}
}
if (!TaskManager::FindMe()->Interrupted())
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
@ -211,12 +223,10 @@ Status DeviceQueueOp::SendDataToGPU() {
is_break_loop = true;
}
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";
GpuBufferMgr::GetInstance().Close(handle);
GpuBufferMgr::GetInstance().CloseConfirm();
return Status::OK();
}
@ -240,8 +250,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
if (ret == BlockQueueStatus_T::ERROR_INPUT) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it.");
} else {
MS_LOG(WARNING) << "Retry pushing data...";
continue;
if (!stop_send_) {
MS_LOG(WARNING) << "Retry pushing data...";
continue;
}
break;
}
} else {
break;
@ -283,13 +296,11 @@ Status DeviceQueueOp::SendDataToCPU() {
MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << ".";
MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << ".";
total_batch++;
if (num_batch_ > 0 && total_batch == num_batch_) {
break;
}
if (stop_send_) break;
}
}
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";
return Status::OK();
}

@ -21,6 +21,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/util/status.h"
#ifdef ENABLE_TDTQUE
@ -84,8 +85,8 @@ class DeviceQueueOp : public PipelineOp {
return *this;
}
Builder &SetNumBatch(int64_t num_batch) {
builder_num_batch_ = num_batch;
Builder &SetSendEpochEnd(bool send_epoch_end) {
builder_send_epoch_end_ = send_epoch_end;
return *this;
}
@ -94,8 +95,9 @@ class DeviceQueueOp : public PipelineOp {
// to call this Build() method. It will instantiate the DeviceQueueOp
// and return it to caller as a shared pointer.
Status Build(std::shared_ptr<DeviceQueueOp> *ptr) {
*ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_);
*ptr =
std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
builder_prefetch_size_, builder_op_connector_size_, builder_send_epoch_end_);
return Status::OK();
}
@ -104,14 +106,14 @@ class DeviceQueueOp : public PipelineOp {
int32_t builder_device_id_;
DeviceType builder_device_type_;
std::string builder_channel_name_;
int64_t builder_num_batch_;
int32_t builder_op_connector_size_;
bool builder_send_epoch_end_;
};
// Name: constructor
// Description
DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
int32_t op_connector_size, int64_t num_batch);
int32_t op_connector_size, bool send_epoch_end);
// Name: destructor
// Description
@ -121,6 +123,8 @@ class DeviceQueueOp : public PipelineOp {
const int32_t get_prefetch_size() { return prefetch_size_; }
void StopSend() { stop_send_ = true; }
// Name: Print()
// Description: A function that prints info about the node
void Print(std::ostream &out, // In: The output stream to print to
@ -149,6 +153,7 @@ class DeviceQueueOp : public PipelineOp {
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
Status CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const;
private:
#ifdef ENABLE_TDTQUE
Status SendDataToAscend();
#endif
@ -164,7 +169,8 @@ class DeviceQueueOp : public PipelineOp {
DeviceType device_type_;
const int32_t device_id_;
const int32_t prefetch_size_;
const int64_t num_batch_;
const bool send_epoch_end_;
bool stop_send_;
#ifdef ENABLE_TDTQUE
std::shared_ptr<TdtPlugin> tdtInstancePtr;

@ -0,0 +1,130 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iomanip>
#include <iostream>
#include <utility>
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// The builder "build" method creates the final object.
Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<EpochCtrlOp>(build_max_repeats_);
return Status::OK();
}
// Constructor
EpochCtrlOp::EpochCtrlOp(int32_t num_epoch) : RepeatOp(num_epoch) { MS_LOG(INFO) << "Welcome to Epoch Ctrl Op."; }
// Destructor
EpochCtrlOp::~EpochCtrlOp() {}
// A print method typically used for debugging
void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <EpochCtrlOp>:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << " [epochs: " << max_repeats_ << "]\n";
} else {
// Call the super class for displaying any common detailed info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_
<< "\nLeaf Nodes in execution path:";
if (!eoe_ops_.empty()) {
for (size_t i = 0; i < eoe_ops_.size(); i++) {
out << "\n Operator: " << eoe_ops_[i]->id();
}
} else {
out << " None.";
}
out << "\n\n";
}
}
Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("EpochCtrlOp can't be the leaf node.");
}
std::unique_ptr<DataBuffer> buf;
// `retry_if_eoe` is false because EpochCtrlOp does not eat EOE.
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, false));
// Only intercept the EOE for EoeReceived processing; after that, the EOE is forwarded to the next op.
// Other databuffers containing data or EOF will simply be forwarded.
// EOF can simply be forwarded because this op does not spawn any thread, thus does not require clean up.
if (buf->eoe()) {
RETURN_IF_NOT_OK(EoeReceived(worker_id));
}
*p_buffer = std::move(buf);
return Status::OK();
}
Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
repeat_count_++;
MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_
<< ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_;
// If we've reached the requested epoch count, then flag the leaf nodes
// to tell them they've got one more epoch to perform. When they reach the end
// of the last epoch, they quit rather than loop again.
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << eoe_op->id();
eoe_op->set_control_flag(kDeOpLastRepeat);
}
}
// This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
state_ = OpState::kDeOpIdle;
if (repeat_count_ != max_repeats_) {
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());
}
}
return Status::OK();
}
// Pre-Visitor accept method for NodePass
Status EpochCtrlOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<EpochCtrlOp>(), modified);
}
// Visitor accept method for NodePass
Status EpochCtrlOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the visitation
return p->RunOnNode(shared_from_base<EpochCtrlOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
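Since EpochCtrlOp forwards the EOE instead of consuming it, a parent op pulling from it sees one EOE per epoch and a single EOF after the final epoch. A minimal sketch of such a consumer loop, written as it might appear inside a parent operator that uses GetNextInput() the way DeviceQueueOp::SendDataToAscend does above (illustrative only, not part of this commit):

  std::unique_ptr<DataBuffer> buf;
  RETURN_IF_NOT_OK(GetNextInput(&buf));
  while (!buf->eof()) {
    if (buf->eoe()) {
      // one epoch finished; EpochCtrlOp has already driven Reset() on the leaf ops if more epochs remain
    } else {
      // process the rows in this databuffer
    }
    RETURN_IF_NOT_OK(GetNextInput(&buf));
  }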

@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
#define DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
namespace mindspore {
namespace dataset {
class EpochCtrlOp : public RepeatOp {
public:
class Builder : public RepeatOp::Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @param count - The number of epochs to run
// @return This is a constructor.
explicit Builder(int32_t count) : RepeatOp::Builder(count) {}
// Default destructor
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new EpochCtrlOp object
Status Build(std::shared_ptr<EpochCtrlOp> *);
};
// Constructor
explicit EpochCtrlOp(int32_t num_epoch);
// Destructor
~EpochCtrlOp();
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since EpochCtrlOp is derived from RepeatOp which is an inlined op, getting a buffer from us
// will simply bounce you to get a buffer from our child.
// Epoch Control Op does not eat the EOE, it will pass the EOE to the next op.
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// Base-class override for handling cases when an eoe is received.
// @param worker_id - The worker id
Status EoeReceived(int32_t worker_id) override;
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
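The InjectionPass itself is not included in this diff, but based on the helpers it introduces or relies on (EpochCtrlOp::Builder, ExecutionTree::AssociateNode now accepting the prepare state, and DatasetOp::InsertAsParent), a minimal sketch of splicing an epoch control op above a chosen node could look like this. The function name, and the choice of target node, are assumptions rather than the actual pass:

  Status InjectEpochCtrl(ExecutionTree *tree, std::shared_ptr<DatasetOp> target, int32_t num_epochs) {
    std::shared_ptr<EpochCtrlOp> epoch_ctrl;
    RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl));
    RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl));     // allowed now that kDeTStatePrepare is accepted
    RETURN_IF_NOT_OK(target->InsertAsParent(epoch_ctrl));  // splice the new op above the chosen node
    return Status::OK();
  }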

@ -132,6 +132,7 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
// Invoke a reset against the eoe nodes only.
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());
}
@ -167,8 +168,9 @@ int32_t RepeatOp::num_consumers() const {
Status RepeatOp::Reset() {
// If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op.
// In that case, we now have to bounce the reset down to our own eoe ops.
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset.";
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());
}
state_ = OpState::kDeOpRunning;

@ -46,7 +46,7 @@ class RepeatOp : public PipelineOp {
// @return shared_ptr to the new RepeatOp object
Status Build(std::shared_ptr<RepeatOp> *);
private:
protected:
int32_t build_max_repeats_;
Status SanityCheck() const;
@ -131,11 +131,11 @@ class RepeatOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return "RepeatOp"; }
/// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
/// \param[in] eoe_op The input leaf/eoe operator to add to the list
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
// \param[in] eoe_op The input leaf/eoe operator to add to the list
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
private:
protected:
int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.

@ -132,8 +132,9 @@ Status ZipOp::prepare(TensorQTable *const table) {
if (eof_) {
return Status::OK();
}
// One of our child iterators encountered EOE. Return and proceed with the draining phase.
if (new_row.empty()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!");
return Status::OK();
}
// Pack this first row into our tensor table

@ -23,6 +23,7 @@
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/monitor.h"
@ -50,11 +51,11 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
if (op->tree_ == this) {
return Status::OK();
}
if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding && tree_state_ != kDeTStatePrepare) {
std::string err_msg =
"Invalid tree state for adding a node. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
" Expected states: " + std::to_string(static_cast<int>(kDeTStateInit)) + " or " +
std::to_string(static_cast<int>(kDeTStateBuilding));
std::to_string(static_cast<int>(kDeTStateBuilding)) + " or " + std::to_string(static_cast<int>(kDeTStatePrepare));
RETURN_STATUS_UNEXPECTED(err_msg);
}
@ -200,7 +201,9 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
// For example, repeatOp inlining
//
// @return Status - The error code return
Status ExecutionTree::Prepare() {
Status ExecutionTree::Prepare(int32_t num_epochs) {
num_epochs_ = num_epochs;
// Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePreAction());
@ -222,6 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() {
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
MS_LOG(INFO) << "Running pre pass loops.";
pre_actions.push_back(std::make_unique<InjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>());
pre_actions.push_back(std::make_unique<CacheTransformPass>());
// Apply pre action passes
@ -278,6 +282,11 @@ Status ExecutionTree::PrepareDeprecated() {
" Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare));
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (root_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree.");
}
// Start the recursive prepare
RETURN_IF_NOT_OK(this->PrepareNode(root_));
tree_state_ = kDeTStateReady;

@ -176,7 +176,7 @@ class ExecutionTree {
// For example, repeatOp inlining
//
// @return Status - The error code return
Status Prepare();
Status Prepare(int num_epochs = -1);
// Compulsory transformation/action pre optimization.
// @return Status - The error code return
@ -193,6 +193,7 @@ class ExecutionTree {
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
// @param Total number of epochs that will be run on this tree
// @return Status - The error code return
Status PrepareDeprecated();
@ -231,6 +232,10 @@ class ExecutionTree {
// Optional optimizations status
bool OptimizationEnabled() const { return optimize_; }
// Getter function to get the total number of epochs to be run on this tree.
// @return total number of epochs
int32_t num_epochs() { return num_epochs_; }
private:
// A helper functions for doing the recursive printing
// @param dataset_op - The dataset op to print
@ -245,6 +250,7 @@ class ExecutionTree {
int32_t id_count_; // Counter for generating operator id's
uint32_t prepare_flags_; // Flags used during tree prepare
TreeState tree_state_; // Tracking the current tree state
int32_t num_epochs_; // Total number of epochs to run for this tree
std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
bool optimize_; // Flag to enable optional optimizations

@ -5,6 +5,7 @@ add_library(engine-opt OBJECT
post/repeat_pass.cc
pre/cache_pass.cc
pre/cache_transform_pass.cc
pre/injection_pass.cc
pre/removal_nodes.cc
pre/removal_pass.cc
optional/tensor_op_fusion_pass.cc

@ -16,11 +16,13 @@
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
@ -230,6 +232,11 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
@ -244,5 +251,15 @@ Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
} // namespace dataset
} // namespace mindspore

@ -77,6 +77,10 @@ class CacheMergeOp;
class CacheLookupOp;
class EpochCtrlOp;
class BuildVocabOp;
// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
@ -190,12 +194,18 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);
private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
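The two EpochCtrlOp overloads added above extend the NodePass double dispatch, so a derived pass can react specifically to epoch control nodes. A hypothetical example pass, not part of this commit, shown only to illustrate how the new hooks are consumed:

  namespace mindspore {
  namespace dataset {
  class CountEpochCtrlPass : public NodePass {
   public:
    Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override {
      num_epoch_ctrl_++;  // routed here instead of the DatasetOp fallback
      *modified = false;
      return Status::OK();
    }
   private:
    int32_t num_epoch_ctrl_ = 0;
  };
  }  // namespace dataset
  }  // namespace mindspore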

@ -20,6 +20,7 @@
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
namespace mindspore {
namespace dataset {
@ -28,6 +29,9 @@ RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(fa
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Create a new stack for eoe operators and push onto our stack of stacks.
std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
eoe_op_stacks_.push(std::move(new_stack));
// If we are already repeated, then this is a nested repeat.
if (is_repeated_) {
nested_repeats_++;
@ -36,6 +40,18 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified)
return Status::OK();
}
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// EpochCtrl is derived from RepeatOp. Generally it should do the identical setup
// that RepeatOp does. However, epoch control is actually simpler because it can
// only exist as the root node so it doesn't need all the nested code.
// Create a new stack for eoe operators and push onto our stack of stacks.
std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
eoe_op_stacks_.push(std::move(new_stack));
is_repeated_ = true;
return Status::OK();
}
// Identifies the subtree below this node as being in a cache merge path
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Turn on the flag that we're under a merge op
@ -47,13 +63,24 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modifi
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) {
node->AddToEoeList(leaf_op);
leaf_op = PopFromEOEOpStack();
}
// At this point, we are done with the save area stack. It's a unique pointer to an empty stack
// at this time, so we can pop it to get rid of it.
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
if (!current_stack->empty()) {
RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!");
}
eoe_op_stacks_.pop();
// If we are a repeat op in the descendant tree of a merge op, then we take the saved lookup
// and add it to the list of eoe/leaf ops for the repeat, removing it from the save area.
// and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed
// from the save area, because the merge op above us may also take action on it later for a different
// case when there is no repeat in the merge leg.
if (is_merge_ && cache_lookup_) {
cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
node->AddToEoeList(std::move(cache_lookup_));
@ -65,16 +92,29 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node);
nested_repeats_--;
}
// If we are not nested, or we were the top-most repeat, now we clear the flag
if (nested_repeats_ == 0) {
} else {
// If we are not nested, or we were the top-most repeat, now we clear the flag
if (nested_repeats_ != 0) {
RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!");
}
is_repeated_ = false;
}
return Status::OK();
}
// Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// Pop the leaf ops from the save-area stack and add them to the eoe node tracking
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) {
node->AddToEoeList(leaf_op);
leaf_op = PopFromEOEOpStack();
}
is_repeated_ = false;
return Status::OK();
}
// CacheOp removes previous leaf ops and replaces them with itself
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
if (is_repeated_) {
@ -118,9 +158,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
// Turns off the tracking for operations under merge op
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Setting the flag is needed since we didn't call the base class DatasetOp version
if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated);
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
// If there was no repeat in the merge cache miss leg, then the cache_lookup
// would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack
if (cache_lookup_) {
AddToEOEOpStack(std::move(cache_lookup_));
}
}
cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used
is_merge_ = false;
cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed
return Status::OK();
}
@ -135,25 +182,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
// In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here.
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node);
} else {
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
// Delay the assignment of this leaf to the eoe stack and allow the merge op processing to handle that.
}
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
// Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will
// add the lookup to the eoe stack
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
return Status::OK();
}
// Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
current_stack->push(dataset_op);
}
// Pops an operator from the eoe operator stack save area
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
std::shared_ptr<DatasetOp> top_op = nullptr;
if (!eoe_stack_.empty()) {
top_op = eoe_stack_.top();
eoe_stack_.pop();
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
if (current_stack != nullptr && !current_stack->empty()) {
top_op = current_stack->top();
current_stack->pop();
}
return top_op;
}
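For reference, the save area used above is now a stack of per-scope eoe stacks rather than a single stack. The header change is not shown in this diff, so the exact declaration below is an assumption about how repeat_pass.h could declare it:

  using eoe_op_stack = std::stack<std::shared_ptr<DatasetOp>>;
  std::stack<std::unique_ptr<eoe_op_stack>> eoe_op_stacks_;  // one eoe save area per repeat / epoch ctrl scope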

Some files were not shown because too many files have changed in this diff.