|
|
@ -22,15 +22,31 @@
|
|
|
|
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
|
|
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace dataset {
|
|
|
|
namespace dataset {
|
|
|
|
|
|
|
|
|
|
|
|
RepeatPass::RepeatPass()
|
|
|
|
RepeatPass::RepeatPass()
|
|
|
|
: num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {}
|
|
|
|
: is_repeated_(false),
|
|
|
|
|
|
|
|
nested_repeats_(0),
|
|
|
|
|
|
|
|
num_repeats_(1),
|
|
|
|
|
|
|
|
num_epochs_(1),
|
|
|
|
|
|
|
|
is_merge_(false),
|
|
|
|
|
|
|
|
is_cached_(false),
|
|
|
|
|
|
|
|
cache_lookup_(nullptr) {}
|
|
|
|
|
|
|
|
|
|
|
|
// Identifies the subtree below this node as being in a repeated path of the tree.
|
|
|
|
// Identifies the subtree below this node as being in a repeated path of the tree.
|
|
|
|
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|
|
|
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|
|
|
|
|
|
|
// Create a new stack for eoe operators and push onto our stack of stacks.
|
|
|
|
|
|
|
|
std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>();
|
|
|
|
|
|
|
|
eoe_op_stacks_.push(std::move(new_stack));
|
|
|
|
|
|
|
|
// If we are already repeated, then this is a nested repeat.
|
|
|
|
|
|
|
|
if (is_repeated_) {
|
|
|
|
|
|
|
|
nested_repeats_++;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
is_repeated_ = true;
|
|
|
|
|
|
|
|
|
|
|
|
// If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
|
|
|
|
// If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
|
|
|
|
// Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
|
|
|
|
// Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
|
|
|
|
if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
|
|
|
|
if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
|
|
|
@ -58,7 +74,9 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modifie
|
|
|
|
// that RepeatOp does. However, epoch control is actually simpler because it can
|
|
|
|
// that RepeatOp does. However, epoch control is actually simpler because it can
|
|
|
|
// only exist as the root node so it doesn't need all the nested code.
|
|
|
|
// only exist as the root node so it doesn't need all the nested code.
|
|
|
|
// Create a new stack for eoe operators and push onto our stack of stacks.
|
|
|
|
// Create a new stack for eoe operators and push onto our stack of stacks.
|
|
|
|
|
|
|
|
std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>();
|
|
|
|
|
|
|
|
eoe_op_stacks_.push(std::move(new_stack));
|
|
|
|
|
|
|
|
is_repeated_ = true;
|
|
|
|
// Get the total number of epochs from the EpochCtrlOp parameter
|
|
|
|
// Get the total number of epochs from the EpochCtrlOp parameter
|
|
|
|
num_epochs_ = node->num_repeats();
|
|
|
|
num_epochs_ = node->num_repeats();
|
|
|
|
// Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
|
|
|
|
// Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
|
|
|
@ -85,6 +103,22 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
|
|
|
|
|
|
|
|
|
|
|
// Hooks up any identified eoe nodes under this repeat.
|
|
|
|
// Hooks up any identified eoe nodes under this repeat.
|
|
|
|
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|
|
|
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|
|
|
|
|
|
|
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
|
|
|
|
|
|
|
|
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
while (leaf_op != nullptr) {
|
|
|
|
|
|
|
|
node->AddToEoeList(leaf_op);
|
|
|
|
|
|
|
|
leaf_op = PopFromEOEOpStack();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// At this point, we are done with the save area stack. It's a unique pointer to an empty stack
|
|
|
|
|
|
|
|
// at this time, so we can pop it to get rid of it.
|
|
|
|
|
|
|
|
op_stack *current_stack = eoe_op_stacks_.top().get();
|
|
|
|
|
|
|
|
if (!current_stack->empty()) {
|
|
|
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
eoe_op_stacks_.pop();
|
|
|
|
|
|
|
|
|
|
|
|
// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
|
|
|
|
// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
|
|
|
|
// and set its total repeats. It is important that the op is removed from the save area,
|
|
|
|
// and set its total repeats. It is important that the op is removed from the save area,
|
|
|
|
// because the merge op above us may also take action on it later for a different case when
|
|
|
|
// because the merge op above us may also take action on it later for a different case when
|
|
|
@ -95,6 +129,18 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|
|
|
cache_lookup_.reset();
|
|
|
|
cache_lookup_.reset();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// If we are a nested repeat, then we add ourself to the repeat stack for the next one above us.
|
|
|
|
|
|
|
|
// A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree.
|
|
|
|
|
|
|
|
if (nested_repeats_ > 0) {
|
|
|
|
|
|
|
|
AddToEOEOpStack(node);
|
|
|
|
|
|
|
|
nested_repeats_--;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
// If we are not nested, or we were the top-most repeat, now we clear the flag
|
|
|
|
|
|
|
|
if (nested_repeats_ != 0) {
|
|
|
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
is_repeated_ = false;
|
|
|
|
|
|
|
|
}
|
|
|
|
if (is_cached_) {
|
|
|
|
if (is_cached_) {
|
|
|
|
AddToCachedOpStack(node);
|
|
|
|
AddToCachedOpStack(node);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -110,6 +156,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|
|
|
|
|
|
|
|
|
|
|
// Hooks up any identified eoe nodes under this repeat.
|
|
|
|
// Hooks up any identified eoe nodes under this repeat.
|
|
|
|
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
|
|
|
|
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
|
|
|
|
|
|
|
|
// Pop the leaf ops from the save-area stack and add them to the eoe node tracking
|
|
|
|
|
|
|
|
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
|
|
|
|
|
|
|
|
while (leaf_op != nullptr) {
|
|
|
|
|
|
|
|
node->AddToEoeList(leaf_op);
|
|
|
|
|
|
|
|
leaf_op = PopFromEOEOpStack();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
is_repeated_ = false;
|
|
|
|
node->set_total_repeats(num_repeats_);
|
|
|
|
node->set_total_repeats(num_repeats_);
|
|
|
|
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
|
|
|
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
|
|
|
// We finish the walk of this EpochCtrl's descendent nodes.
|
|
|
|
// We finish the walk of this EpochCtrl's descendent nodes.
|
|
|
@ -138,6 +191,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
|
|
|
return Status::OK();
|
|
|
|
return Status::OK();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
|
|
|
|
|
|
|
// If we are in a repeat path, then set our repeated flag
|
|
|
|
|
|
|
|
if (is_repeated_) {
|
|
|
|
|
|
|
|
// if infinite repeat save ourself in a stack for the repeat operator above us
|
|
|
|
|
|
|
|
if (num_repeats_ < 0) {
|
|
|
|
|
|
|
|
AddToEOEOpStack(node);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// If we are under a cache op, then save ourselves to the cached op stack.
|
|
|
|
|
|
|
|
if (is_cached_) {
|
|
|
|
|
|
|
|
AddToCachedOpStack(node);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set total repeats and total epochs for the node
|
|
|
|
|
|
|
|
node->set_total_repeats(num_repeats_);
|
|
|
|
|
|
|
|
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
|
|
|
|
|
|
|
return Status::OK();
|
|
|
|
|
|
|
|
}
|
|
|
|
// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
|
|
|
// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
|
|
|
// for use with a controlling repeat above it.
|
|
|
|
// for use with a controlling repeat above it.
|
|
|
|
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
|
|
|
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
|
|
@ -190,6 +260,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified
|
|
|
|
return Status::OK();
|
|
|
|
return Status::OK();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Adds an operator to the eoe operator stack save area
|
|
|
|
|
|
|
|
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
|
|
|
|
|
|
|
|
op_stack *current_stack = eoe_op_stacks_.top().get();
|
|
|
|
|
|
|
|
current_stack->push(dataset_op);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Pops an operator from the eoe operator stack save area
|
|
|
|
|
|
|
|
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
|
|
|
|
|
|
|
|
std::shared_ptr<DatasetOp> top_op = nullptr;
|
|
|
|
|
|
|
|
op_stack *current_stack = eoe_op_stacks_.top().get();
|
|
|
|
|
|
|
|
if (current_stack != nullptr && !current_stack->empty()) {
|
|
|
|
|
|
|
|
top_op = current_stack->top();
|
|
|
|
|
|
|
|
current_stack->pop();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return top_op;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Adds an operator to the cached operator stack save area
|
|
|
|
// Adds an operator to the cached operator stack save area
|
|
|
|
void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); }
|
|
|
|
void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); }
|
|
|
|
|
|
|
|
|
|
|
|