diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h index a9507f0fb5..0c3b168e21 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "minddata/mindrecord/include/common/shard_utils.h" @@ -38,10 +39,16 @@ class ShardTask { void MakePerm(); - void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, - const json &label); + inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, + const json &label); - void InsertTask(std::tuple, std::vector, json> task); + inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, + const std::vector &offset, const json &label); + + inline void InsertTask(std::tuple, std::vector, json> task); + + inline void InsertTask(const uint32_t &i, + std::tuple, std::vector, json> task); void PopBack(); @@ -56,12 +63,41 @@ class ShardTask { static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements, int64_t num_samples); + inline void ResizeTask(const uint32_t &size); + uint32_t categories; std::vector permutation_; std::vector, std::vector, json>> task_list_; }; + +inline void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, + const json &label) { + MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id + << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; + task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); +} + +inline void ShardTask::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, + const std::vector &offset, const json &label) { + task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label}; +} + +inline void ShardTask::InsertTask(std::tuple, std::vector, json> task) { + MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task)) + << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() + << ", size of task_list_: " << task_list_.size() << "."; + + task_list_.push_back(std::move(task)); +} + +inline void ShardTask::InsertTask(const uint32_t &i, + std::tuple, std::vector, json> task) { + task_list_[i] = std::move(task); +} + +inline void ShardTask::ResizeTask(const uint32_t &size) { task_list_.resize(size); } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index f6b2ac856f..002f9c8c3e 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -14,6 +14,9 @@ * limitations under the License. */ +#include +#include + #include "minddata/mindrecord/include/shard_distributed_sample.h" #include "minddata/mindrecord/include/shard_reader.h" #include "utils/ms_utils.h" @@ -1036,15 +1039,37 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector(ret) != SUCCESS) { return FAILED; } - auto offsets = std::get<1>(ret); - auto local_columns = std::get<2>(ret); + auto &offsets = std::get<1>(ret); + auto &local_columns = std::get<2>(ret); if (shard_count_ <= kMaxFileCount) { + int sample_count = 0; for (int shard_id = 0; shard_id < shard_count_; shard_id++) { - for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { - tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], - std::vector{offsets[shard_id][i][2], offsets[shard_id][i][3]}, - local_columns[shard_id][i]); - } + sample_count += offsets[shard_id].size(); + } + MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset."; + + // Init the tasks_ size + tasks_.ResizeTask(sample_count); + + // Init the task threads, maybe use ThreadPool is better + std::vector init_tasks_thread(shard_count_); + + uint32_t current_offset = 0; + for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { + init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() { + auto offset = current_offset; + for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { + tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], + std::vector{offsets[shard_id][i][2], offsets[shard_id][i][3]}, + local_columns[shard_id][i]); + offset++; + } + }); + current_offset += offsets[shard_id].size(); + } + + for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { + init_tasks_thread[shard_id].join(); } } else { return FAILED; diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc index bfacc90ce6..e760920400 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc @@ -44,21 +44,6 @@ void ShardTask::MakePerm() { } } -void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, - const json &label) { - MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id - << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; - task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); -} - -void ShardTask::InsertTask(std::tuple, std::vector, json> task) { - MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task)) - << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() - << ", size of task_list_: " << task_list_.size() << "."; - - task_list_.push_back(std::move(task)); -} - void ShardTask::PopBack() { task_list_.pop_back(); } uint32_t ShardTask::Size() const { return static_cast(task_list_.size()); }