!4149 Fix bug in counting epochs when DeviceQueue is used

Merge pull request !4149 from h.farahat/to_device_bug
pull/4149/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0df5a56159

@ -20,6 +20,7 @@
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
namespace mindspore {
@ -258,6 +259,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
return Status::OK();
}
Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
// Set total repeats and total epochs for the DeviceQueueOp
node->set_total_repeats(num_epochs_);
node->set_num_repeats_per_epoch(1);
return Status::OK();
}
// Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
op_stack *current_stack = eoe_op_stacks_.top().get();

@ -92,6 +92,12 @@ class RepeatPass : public NodePass {
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;
/// \brief Set the epoch count for DeviceQueue
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
/// for use with a controlling repeat above it.
/// \param[in] node The node being visited

Loading…
Cancel
Save