|
|
|
@ -21,6 +21,7 @@
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "minddata/mindrecord/include/common/shard_utils.h"
|
|
|
|
|
|
|
|
|
@ -38,10 +39,16 @@ class ShardTask {
|
|
|
|
|
|
|
|
|
|
void MakePerm();
|
|
|
|
|
|
|
|
|
|
void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
|
|
|
|
|
const json &label);
|
|
|
|
|
inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
|
|
|
|
|
const json &label);
|
|
|
|
|
|
|
|
|
|
void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
|
|
|
|
|
inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
|
|
|
|
|
const std::vector<uint64_t> &offset, const json &label);
|
|
|
|
|
|
|
|
|
|
inline void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
|
|
|
|
|
|
|
|
|
|
inline void InsertTask(const uint32_t &i,
|
|
|
|
|
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
|
|
|
|
|
|
|
|
|
|
void PopBack();
|
|
|
|
|
|
|
|
|
@ -56,12 +63,41 @@ class ShardTask {
|
|
|
|
|
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
|
|
|
|
|
int64_t num_samples);
|
|
|
|
|
|
|
|
|
|
inline void ResizeTask(const uint32_t &size);
|
|
|
|
|
|
|
|
|
|
uint32_t categories;
|
|
|
|
|
|
|
|
|
|
std::vector<int> permutation_;
|
|
|
|
|
|
|
|
|
|
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Appends one task, assembled from its individual fields, to the end of task_list_.
//
// task_type: kind of the task being recorded.
// shard_id / group_id: stored together as the (shard_id, group_id) pair of the task.
// offset: offsets copied into the new task entry.
// label: JSON label copied into the new task entry.
inline void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
                                  const json &label) {
  // The size written to the log is the size of task_list_ before this task is appended.
  MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
                << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";

  auto location = std::make_tuple(shard_id, group_id);
  task_list_.emplace_back(task_type, std::move(location), offset, label);
}
|
|
|
|
|
|
|
|
|
|
inline void ShardTask::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
|
|
|
|
|
const std::vector<uint64_t> &offset, const json &label) {
|
|
|
|
|
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
|
|
|
|
|
MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task))
|
|
|
|
|
<< ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
|
|
|
|
|
<< ", size of task_list_: " << task_list_.size() << ".";
|
|
|
|
|
|
|
|
|
|
task_list_.push_back(std::move(task));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void ShardTask::InsertTask(const uint32_t &i,
|
|
|
|
|
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
|
|
|
|
|
task_list_[i] = std::move(task);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Resizes task_list_ to hold exactly `size` entries via std::vector::resize: entries beyond `size`
// are dropped, and any newly created entries are value-initialized.
inline void ShardTask::ResizeTask(const uint32_t &size) {
  task_list_.resize(size);
}
|
|
|
|
|
} // namespace mindrecord
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|
|
|
|
|
|