|
|
|
@ -20,6 +20,7 @@
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
@ -258,6 +259,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
|
|
|
|
|
// Set total repeats and total epochs for the DeviceQueueOp
|
|
|
|
|
node->set_total_repeats(num_epochs_);
|
|
|
|
|
node->set_num_repeats_per_epoch(1);
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Adds an operator to the eoe operator stack save area
|
|
|
|
|
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
|
|
|
|
|
op_stack *current_stack = eoe_op_stacks_.top().get();
|
|
|
|
|