From f55cda2c562ef9725b19c95c64cc76b73cf69911 Mon Sep 17 00:00:00 2001
From: yanghaitao
Date: Thu, 28 May 2020 09:30:45 +0800
Subject: [PATCH] fix tfreadDataset hang when shard_equal_rows is true

---
Review notes (text between the "---" line and the diff is ignored by `git am`):

* Before this patch both FillIOBlock* loops set finish = false whenever
  equal_rows_per_shard_ was set and pre_count was still below this
  device's row quota. The inner for-loop breaks out early once
  load_io_block_queue_ is false, so the patch records that in
  end_of_epoch and lets the outer while-loop terminate instead of
  repeating.
* end_of_epoch has to start out false in FillIOBlockNoShuffle, exactly as
  in FillIOBlockShuffle. Initialised to true, `!end_of_epoch` can never
  hold: the outer loop would always stop after one pass and the
  assignment inside the for-loop would be dead code.
* NOTE(review): this copy of the patch seems to have lost angle-bracketed
  text (author e-mail, template arguments, e.g. `static_cast(device_id_)`
  and `std::unique_lock lock(...)`); confirm against the original commit
  before applying.

 .../dataset/engine/datasetops/source/tf_reader_op.cc | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
index 1335344e6d..60adddb4a8 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
@@ -433,11 +433,13 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector &i_keys) {
   int64_t start_offset = 0;
   int64_t end_offset = 0;
   bool finish = false;
+  bool end_of_epoch = false;
   while (!finish) {
     for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
       {
         std::unique_lock lock(load_io_block_queue_mutex_);
         if (load_io_block_queue_ == false) {
+          end_of_epoch = true;
           break;
         }
       }
@@ -461,7 +463,8 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector &i_keys) {
         pre_count += filename_numrows_[file_name];
       }
     }
-    if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) {
+    if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_ &&
+        !end_of_epoch) {
       finish = false;
     } else {
       finish = true;
@@ -478,12 +481,14 @@ Status TFReaderOp::FillIOBlockNoShuffle() {
   int64_t start_offset = 0;
   int64_t end_offset = 0;
   bool finish = false;
+  bool end_of_epoch = false;
   while (!finish) {
     // Iterate over all the keys and add one key to each block.
     for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
       {
         std::unique_lock lock(load_io_block_queue_mutex_);
         if (load_io_block_queue_ == false) {
+          end_of_epoch = true;
           break;
         }
       }
@@ -505,7 +510,8 @@ Status TFReaderOp::FillIOBlockNoShuffle() {
         pre_count += filename_numrows_[file_name];
       }
     }
-    if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) {
+    if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_ &&
+        !end_of_epoch) {
       finish = false;
     } else {
       finish = true;