From ad91e6e51d6c6c555774706fe073469fe715031d Mon Sep 17 00:00:00 2001 From: hesham Date: Sat, 8 Aug 2020 01:49:21 -0400 Subject: [PATCH] - Fix bug in counting epochs when DeviceQueue is used --- .../ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc | 8 ++++++++ .../ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc index aac0eaa2e9..d737a0fa1b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc @@ -20,6 +20,7 @@ #include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" #include "minddata/dataset/engine/datasetops/cache_merge_op.h" +#include "minddata/dataset/engine/datasetops/device_queue_op.h" #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" namespace mindspore { @@ -258,6 +259,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified return Status::OK(); } +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Set total repeats and total epochs for the DeviceQueueOp + node->set_total_repeats(num_epochs_); + node->set_num_repeats_per_epoch(1); + return Status::OK(); +} + // Adds an operator to the eoe operator stack save area void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { op_stack *current_stack = eoe_op_stacks_.top().get(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h index 1e865eadac..082f8e2af3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h @@ -92,6 +92,12 @@ class RepeatPass : public NodePass { /// \return Status The error code return Status RunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Set the epoch count for DeviceQueue + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up /// for use with a controlling repeat above it. /// \param[in] node The node being visited