diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc
index 54c998a92d..5ff8151c0e 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.cc
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc
@@ -391,35 +391,27 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
-Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *in_partitions) {
-  if (args["partitions"].is_none()) {
-    std::string err_msg = "Error: partitions is not set (None)";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  py::list list = py::reinterpret_borrow<py::list>(args["partitions"]);
-  for (auto l : list) {
-    if (!l.is_none()) {
-      in_partitions->push_back(ToInt(l));
+Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
+                                               std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
+                                               int num_padded) {
+  auto sampler = py::reinterpret_borrow<py::object>(handle);
+  auto create = sampler.attr("create_for_minddataset");
+  auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
+  std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
+  while (op != nullptr) {
+    auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
+    if (sampler_op && num_padded > 0) {
+      sampler_op->SetNumPaddedSamples(num_padded);
+      stack_ops.push(sampler_op);
+    } else {
+      stack_ops.push(op);
     }
+    op = op->GetChildOp();
   }
-
-  if (in_partitions->size() != 2) {
-    std::string err_msg = "Error: partitions is invalid or not set.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  constexpr int kMaxPartitions = 1024;
-  if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) {
-    std::string err_msg = "Error: partitions is invalid or not set.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  if (in_partitions->at(1) < 0 || in_partitions->at(1) >= in_partitions->at(0)) {
-    std::string err_msg = "Error: partitions is invalid or not set.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
+  while (!stack_ops.empty()) {
+    operators->push_back(stack_ops.top());
+    stack_ops.pop();
   }
-
   return Status::OK();
 }
@@ -460,34 +452,16 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
         (void)builder->SetNumMindRecordWorkers(ToInt(value));
       } else if (key == "block_reader" && ToBool(value) == true) {
         (void)builder->SetBlockReader();
-      } else if (key == "shuffle_option" && ToBool(value) == true) {
-        if (!args["partitions"].is_none()) continue;
-        uint32_t seed = GetSeed();
-        operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
       } else if (key == "sampler") {
-        auto sampler = py::reinterpret_borrow<py::object>(value);
-        auto create = sampler.attr("_create_for_minddataset");
-        auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
-        operators.push_back(op);
+        int num_padded = 0;
+        if (!args["num_padded"].is_none()) {
+          num_padded = ToInt(args["num_padded"]);
+        }
+        RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded));
       }
     }
   }
 
-  std::vector<int> in_partitions;
-  if (!args["partitions"].is_none()) {
-    auto ret = CheckMindRecordPartitionInfo(args, &in_partitions);
-    if (Status::OK() != ret) {
-      return ret;
-    }
-    auto shuffle = ToBool(args["shuffle_option"]);
-    int num_padded = 0;
-    if (!args["num_padded"].is_none()) {
-      num_padded = ToInt(args["num_padded"]);
-    }
-    operators.push_back(
-      std::make_shared<mindrecord::ShardDistributedSample>(in_partitions[0], in_partitions[1], num_padded, shuffle, 0));
-  }
-
   if (!operators.empty()) {
     (void)builder->SetOperators(operators);
   }
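The new `BuildMindrecordSamplerChain` walks the Python-built sampler chain from the root sampler down through `GetChildOp()`, then unwinds a stack so the leaf operator lands in `operators` first and the root last; a child sampler is therefore applied to the task list before its parent. A minimal Python sketch of that ordering (the `Op` class and names here are illustrative, not part of the MindSpore API):

```python
class Op:
    """Toy stand-in for a ShardOperator holding a single child pointer."""
    def __init__(self, name, child=None):
        self.name = name
        self.child = child

def flatten_leaf_first(root):
    stack = []
    op = root
    while op is not None:   # walk root -> leaf, mirroring GetChildOp()
        stack.append(op)
        op = op.child
    order = []
    while stack:            # unwind: the leaf comes off the stack first
        order.append(stack.pop().name)
    return order

# A distributed sampler with a sequential child executes the sequential op first:
print(flatten_leaf_first(Op("distributed", Op("sequential"))))
# ['sequential', 'distributed']
```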
diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h
index 493c092b1f..ef594f0521 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.h
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.h
@@ -18,6 +18,7 @@
 #include
 #include
+#include <stack>
 #include
 #include
 #include
@@ -108,10 +109,12 @@ class DEPipeline {
   Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
-  Status CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *ptr);
-
   Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
+  Status BuildMindrecordSamplerChain(const py::handle &handle,
+                                     std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
+                                     int num_padded);
+
   Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
   Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc
index 65d0330ff0..b034b618ef 100644
--- a/mindspore/ccsrc/dataset/api/python_bindings.cc
+++ b/mindspore/ccsrc/dataset/api/python_bindings.cc
@@ -71,6 +71,7 @@
 #include "mindrecord/include/shard_pk_sample.h"
 #include "mindrecord/include/shard_distributed_sample.h"
 #include "mindrecord/include/shard_sample.h"
+#include "mindrecord/include/shard_sequential_sample.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
@@ -165,8 +166,8 @@ void bindDatasetOps(py::module *m) {
                       const int64_t num_padded) {
            int64_t count = 0;
            std::shared_ptr<mindrecord::ShardOperator> op;
-           if (py::hasattr(sampler, "_create_for_minddataset")) {
-             auto create = sampler.attr("_create_for_minddataset");
+           if (py::hasattr(sampler, "create_for_minddataset")) {
+             auto create = sampler.attr("create_for_minddataset");
              op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
            }
            THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
@@ -486,7 +487,9 @@ void bindSamplerOps(py::module *m) {
     .def("add_child", [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) {
       THROW_IF_ERROR(self->AddChild(child));
     });
 
-  (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
+  (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator")
+    .def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
+                         std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
 
   (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
     .def(py::init<int64_t, int64_t, bool, uint32_t>());
@@ -518,6 +521,22 @@ void bindSamplerOps(py::module *m) {
       }
     }));
 
+  (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
+                   std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
+    .def(py::init<int, int, bool, uint32_t>());
+
+  (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
+    *m, "MindrecordRandomSampler")
+    .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
+      return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
+    }));
+
+  (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
+                   std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler")
+    .def(py::init([](int num_samples, int start_index) {
+      return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
+    }));
+
   (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
     .def(py::init<std::vector<double>, bool>());
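On the binding side, `ShardOperator` now exposes `add_child`, and the three new `Mindrecord*Sampler` classes wrap `ShardDistributedSample`, `ShardShuffle`, and `ShardSequentialSample`. From Python, this is what makes chained samplers work end to end; a hedged sketch, assuming the Python-side `add_child` chaining used elsewhere in this patch:

```python
import mindspore.dataset as ds

# Shard the first four samples across two shards: the SequentialSampler child
# is converted first, then linked under the DistributedSampler's C++ op via
# the new ShardOperator.add_child binding.
child = ds.SequentialSampler(0, 4)
parent = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False)
parent.add_child(child)
# data_set = ds.MindDataset("/path/to/file.mindrecord", sampler=parent)  # path is a placeholder
```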
diff --git a/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h
index 92866a4b35..cac1ee0ac2 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h
@@ -31,6 +31,10 @@ class ShardDistributedSample : public ShardSample {
  public:
   ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
 
+  ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
+
+  void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
+
   ~ShardDistributedSample() override{};
 
   MSRStatus PreExecute(ShardTask &tasks) override;
diff --git a/mindspore/ccsrc/mindrecord/include/shard_operator.h b/mindspore/ccsrc/mindrecord/include/shard_operator.h
index 59c77074a1..f33e3db5f4 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_operator.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_operator.h
@@ -17,6 +17,7 @@
 #ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
 #define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
 
+#include <memory>
 #include "mindrecord/include/shard_task.h"
 
 namespace mindspore {
@@ -37,6 +38,14 @@ class ShardOperator {
     }
     return SUCCESS;
   }
+
+  virtual bool HasChildOp() { return child_op_ != nullptr; }
+
+  virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) {
+    if (child_op != nullptr) child_op_ = child_op;
+    return SUCCESS;
+  }
+
+  virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
 
   virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
 
@@ -44,7 +53,10 @@ class ShardOperator {
   virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
 
-  virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; }
+  virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
+
+ private:
+  std::shared_ptr<ShardOperator> child_op_ = nullptr;
 };
 }  // namespace mindrecord
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h
index 9be017c646..1f2138d6d5 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_reader.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h
@@ -34,6 +34,7 @@
 #include
 #include
 #include
+#include <stack>
 #include
 #include
 #include
@@ -44,6 +45,7 @@
 #include "mindrecord/include/common/shard_utils.h"
 #include "mindrecord/include/shard_category.h"
 #include "mindrecord/include/shard_column.h"
+#include "mindrecord/include/shard_distributed_sample.h"
 #include "mindrecord/include/shard_error.h"
 #include "mindrecord/include/shard_index_generator.h"
 #include "mindrecord/include/shard_operator.h"
diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h
index 111df3bc1a..a32acbff6e 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_sample.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_sample.h
@@ -48,10 +48,10 @@ class ShardSample : public ShardOperator {
   int numerator_;
   int denominator_;
   int partition_id_;
+  int no_of_samples_;
   std::shared_ptr<ShardShuffle> shuffle_op_;
 
  private:
-  int no_of_samples_;
   std::vector<int64_t> indices_;
   SamplerType sampler_type_;
 };
diff --git a/mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h
new file mode 100644
index 0000000000..a8ee3a36db
--- /dev/null
+++ b/mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_
+#define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_
+
+#include
+#include
+#include
+#include
+#include "mindrecord/include/shard_sample.h"
+
+namespace mindspore {
+namespace mindrecord {
+class ShardSequentialSample : public ShardSample {
+ public:
+  ShardSequentialSample(int n, int offset);
+
+  ShardSequentialSample(float per, float per_offset);
+
+  ~ShardSequentialSample() override{};
+
+  MSRStatus Execute(ShardTask &tasks) override;
+
+  int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
+
+ private:
+  int offset_;
+  float per_;
+  float per_offset_;
+};
+}  // namespace mindrecord
+}  // namespace mindspore
+
+#endif  // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_
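`ShardSequentialSample` takes either an absolute count `n` (0 meaning "all") plus a start `offset`, or a percentage, and indexes the task list modulo its size, so a window that runs past the end wraps around. A rough re-implementation of the count/offset path for intuition (this is a sketch, not the C++ API):

```python
def sequential_sample(tasks, n=0, offset=0):
    """Illustrative model of ShardSequentialSample::Execute (count/offset form)."""
    total = len(tasks)
    taking = total if n == 0 else n
    return [tasks[i % total] for i in range(offset, taking + offset)]

print(sequential_sample(list(range(10)), n=0, offset=2))  # all ten rows, starting at index 2
print(sequential_sample(list(range(10)), n=4, offset=8))  # [8, 9, 0, 1] -- wraps past the end
```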
diff --git a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h
index a9c54e6239..adb172bdcc 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h
@@ -26,12 +26,20 @@ class ShardShuffle : public ShardOperator {
  public:
   explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory);
 
+  ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
+               ShuffleType shuffle_type = kShuffleSample);
+
   ~ShardShuffle() override{};
 
   MSRStatus Execute(ShardTask &tasks) override;
 
+  int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
+
  private:
   uint32_t shuffle_seed_;
+  int64_t no_of_samples_;
+  bool replacement_;
+  bool reshuffle_each_epoch_;
   ShuffleType shuffle_type_;
 };
diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc
index 9d6ea969ea..73d297a2af 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc
@@ -792,24 +792,51 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
 }
 
 MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
-                                      const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded) {
+                                      const std::shared_ptr<ShardOperator> &ops, int64_t *count, const int num_padded) {
   if (SUCCESS != Init(file_paths, load_dataset)) {
     return FAILED;
   }
   int64_t num_samples = num_rows_;
-  if (std::dynamic_pointer_cast<ShardCategory>(op)) {
-    auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
-    std::string category_field = category_op->GetCategoryField();
-    auto num_classes = GetNumClasses(category_field);
-    num_samples = category_op->GetNumSamples(num_rows_, num_classes);
-  } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
-    num_samples = op->GetNumSamples(num_rows_, 0);
-    if (-1 == num_samples) {
-      MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
-      return FAILED;
+  bool root = true;
+  std::stack<std::shared_ptr<ShardOperator>> stack_ops;
+  std::shared_ptr<ShardOperator> op(ops);
+  while (op != nullptr) {
+    stack_ops.push(op);
+    op = op->GetChildOp();
+  }
+  while (!stack_ops.empty()) {
+    op = stack_ops.top();
+    stack_ops.pop();
+    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
+      num_samples = op->GetNumSamples(num_samples, 0);
+      if (num_padded > 0 && root == true) {
+        num_samples += num_padded;
+        MS_LOG(DEBUG) << "Padding samples work on shuffle sampler.";
+        root = false;
+      }
+    } else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
+      auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
+      std::string category_field = category_op->GetCategoryField();
+      auto num_classes = GetNumClasses(category_field);
+      num_samples = category_op->GetNumSamples(num_samples, num_classes);
+    } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
+      if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
+        auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
+        if (root == true) {
+          sampler_op->SetNumPaddedSamples(num_padded);
+          num_samples = op->GetNumSamples(num_samples, 0);
+          if (-1 == num_samples) {
+            MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
+            return FAILED;
+          }
+          root = false;
+        }
+      } else {
+        num_samples = op->GetNumSamples(num_samples, 0);
+      }
+    } else {
+      if (num_padded > 0) num_samples += num_padded;
     }
-  } else {
-    if (num_padded > 0) num_samples += num_padded;
   }
   *count = num_samples;
   return SUCCESS;
@@ -1385,12 +1412,16 @@ void ShardReader::Reset() {
 }
 
 void ShardReader::ShuffleTask() {
+  if (block_reader_) return;
+
+  // if ops contain both a shuffle op and a distributed sampler, skip the extra shuffle
+  bool has_sharding = false;
   for (const auto &op : operators_) {
-    if (block_reader_) {
-      continue;
+    if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
+      has_sharding = true;
     }
-
-    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
+  }
+  for (const auto &op : operators_) {
+    if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
       if (SUCCESS != (*op)(tasks_)) {
         MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
       }
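`CountTotalRows` now mirrors the execution order: operators are pushed root-to-leaf and popped leaf-to-root, threading `num_samples` through each op's `GetNumSamples`, and `num_padded` is credited only once, at the first shuffle or distributed-sample op encountered. A loose Python model of that loop (the `ToyOp` type is invented for illustration, and the padding placement is simplified):

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class ToyOp:
    get_num_samples: Callable[[int], int]
    accepts_padding: bool = False  # stands in for the ShardShuffle/ShardDistributedSample checks

def count_total_rows(ops_root_to_leaf, num_rows, num_padded):
    stack = list(ops_root_to_leaf)   # root pushed first, leaf last
    num_samples = num_rows
    padded_applied = False
    while stack:
        op = stack.pop()             # so the leaf is evaluated first
        num_samples = op.get_num_samples(num_samples)
        if num_padded > 0 and not padded_applied and op.accepts_padding:
            num_samples += num_padded
            padded_applied = True    # padding is counted exactly once
    return num_samples

leaf = ToyOp(lambda n: min(n, 6))                     # e.g. take 6 rows
root = ToyOp(lambda n: n // 2, accepts_padding=True)  # e.g. split over 2 shards
print(count_total_rows([root, leaf], num_rows=10, num_padded=0))  # 3
```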
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc
index d95ad1f268..b0f51a77c8 100644
--- a/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc
+++ b/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc
@@ -31,6 +31,9 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int
   shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
 }
 
+ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
+    : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
+
 int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (no_of_padded_samples_ <= 0) {
     if (dataset_size % denominator_ == 0) {
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc
new file mode 100644
index 0000000000..a7fa4e7343
--- /dev/null
+++ b/mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mindrecord/include/shard_sequential_sample.h"
+
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::ERROR;
+
+namespace mindspore {
+namespace mindrecord {
+ShardSequentialSample::ShardSequentialSample(int n, int offset)
+    : ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {}
+
+ShardSequentialSample::ShardSequentialSample(float per, float per_offset)
+    : ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {}
+
+int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
+  if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
+    return dataset_size;
+  }
+  if (per_ > kEpsilon && per_ <= 1.0f) {
+    return dataset_size * per_;
+  }
+  return no_of_samples_;
+}
+
+MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
+  int total_no = static_cast<int>(tasks.Size());
+  int taking;
+  if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
+    taking = total_no;
+  } else if (per_ > kEpsilon && per_ <= 1.0f) {
+    taking = total_no * per_;
+  } else {
+    taking = no_of_samples_;
+  }
+
+  if (tasks.permutation_.empty()) {
+    ShardTask new_tasks;
+    total_no = static_cast<int>(tasks.Size());
+    for (int i = offset_; i < taking + offset_; ++i) {
+      new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
+    }
+    std::swap(tasks, new_tasks);
+  } else {  // shuffled
+    ShardTask new_tasks;
+    if (taking > static_cast<int>(tasks.permutation_.size())) {
+      return FAILED;
+    }
+    total_no = static_cast<int>(tasks.permutation_.size());
+    for (size_t i = offset_; i < taking + offset_; ++i) {
+      new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
+    }
+    std::swap(tasks, new_tasks);
+  }
+  return SUCCESS;
+}
+
+}  // namespace mindrecord
+}  // namespace mindspore
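The delegating `ShardDistributedSample` constructor starts `no_of_padded_samples_` at zero so `SetNumPaddedSamples` can fill it in later, and `GetNumSamples` only accepts a padded dataset whose total divides evenly across shards. The arithmetic behind that contract, with made-up numbers:

```python
# Hypothetical figures: 10 rows to be served by 4 shards.
dataset_size, num_shards = 10, 4

# Smallest num_padded that makes the total divisible by num_shards:
num_padded = (num_shards - dataset_size % num_shards) % num_shards   # -> 2
assert (dataset_size + num_padded) % num_shards == 0
rows_per_shard = (dataset_size + num_padded) // num_shards           # -> 3

# Any other padding (say 1) trips the "not divisible" error path.
assert (dataset_size + 1) % num_shards != 0
```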
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
index d33400ef38..ce7573af46 100644
--- a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
+++ b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
@@ -21,17 +21,52 @@
 namespace mindspore {
 namespace mindrecord {
 ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
-    : shuffle_seed_(seed), shuffle_type_(shuffle_type) {}
+    : shuffle_seed_(seed),
+      no_of_samples_(0),
+      replacement_(false),
+      reshuffle_each_epoch_(true),
+      shuffle_type_(shuffle_type) {}
+
+ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
+                           ShuffleType shuffle_type)
+    : shuffle_seed_(seed),
+      no_of_samples_(no_of_samples),
+      replacement_(replacement),
+      reshuffle_each_epoch_(reshuffle_each_epoch),
+      shuffle_type_(shuffle_type) {}
+
+int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
+  if (replacement_) {
+    return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
+  }
+  return dataset_size;
+}
 
 MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
   if (tasks.categories < 1) {
     return FAILED;
   }
-  if (shuffle_type_ == kShuffleSample) {
+  if (shuffle_type_ == kShuffleSample) {  // shuffle each sample
     if (tasks.permutation_.empty() == true) {
       tasks.MakePerm();
     }
-    std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+    if (replacement_ == true) {
+      ShardTask new_tasks;
+      if (no_of_samples_ == 0) {
+        no_of_samples_ = static_cast<int64_t>(tasks.Size());
+      }
+      if (no_of_samples_ <= 0) {
+        MS_LOG(ERROR) << "no_of_samples need to be positive.";
+        return FAILED;
+      }
+      new_tasks.task_list_.reserve(no_of_samples_);
+      for (uint32_t i = 0; i < no_of_samples_; ++i) {
+        new_tasks.InsertTask(tasks.GetRandomTask());
+      }
+      std::swap(tasks, new_tasks);
+    } else {
+      std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+    }
   } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
     uint32_t individual_size = tasks.Size() / tasks.categories;
     std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
@@ -46,7 +81,7 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
       }
     }
   }
-  shuffle_seed_++;
+  if (reshuffle_each_epoch_) shuffle_seed_++;
   return SUCCESS;
 }
 }  // namespace mindrecord
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/mindrecord/meta/shard_task.cc
index 0a8d8e3d43..13b254974d 100644
--- a/mindspore/ccsrc/mindrecord/meta/shard_task.cc
+++ b/mindspore/ccsrc/mindrecord/meta/shard_task.cc
@@ -72,6 +72,7 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
   std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
   return task_list_[dis(gen)];
 }
+
 ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
   ShardTask res;
   if (category_tasks.empty()) return res;
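`ShardShuffle` now has two sample-level modes: without replacement it permutes the whole task list, while with replacement it draws `no_of_samples_` tasks independently via `GetRandomTask`, so duplicates are possible; `reshuffle_each_epoch_` controls whether the seed advances between epochs. A rough Python equivalent (illustrative only):

```python
import random

def shard_shuffle(tasks, seed, replacement=False, no_of_samples=0, reshuffle_each_epoch=True):
    rng = random.Random(seed)
    if replacement:
        n = no_of_samples or len(tasks)               # 0 means "dataset size"
        out = [rng.choice(tasks) for _ in range(n)]   # independent draws, duplicates allowed
    else:
        out = list(tasks)
        rng.shuffle(out)                              # plain permutation
    next_seed = seed + 1 if reshuffle_each_epoch else seed
    return out, next_seed

epoch1, seed = shard_shuffle(list(range(10)), seed=0)
epoch2, seed = shard_shuffle(list(range(10)), seed=seed)  # a different order next epoch
```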
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 12151d4737..017d7a5058 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -1015,10 +1015,8 @@ class Dataset:
     def get_distribution(output_dataset):
         dev_id = 0
-        if isinstance(output_dataset, (MindDataset)):
-            return output_dataset.distribution, dev_id
         if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
-                                       ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
+                                       ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)):
             sampler = output_dataset.sampler
             if isinstance(sampler, samplers.DistributedSampler):
                 dev_id = sampler.shard_id
@@ -2670,7 +2668,7 @@ class MnistDataset(MappableDataset):
         return self.sampler.is_sharded()
 
 
-class MindDataset(SourceDataset):
+class MindDataset(MappableDataset):
     """
     A source dataset that reads from shard files and database.
 
@@ -2687,11 +2685,13 @@ class MindDataset(SourceDataset):
         sampler (Sampler, optional): Object used to choose samples from the
            dataset (default=None, sampler is exclusive
            with shuffle and block_reader). Support list: SubsetRandomSampler,
-           PkSampler.
+           PkSampler, RandomSampler, SequentialSampler, DistributedSampler.
         padded_sample (dict, optional): Samples will be appended to the dataset, whose keys are the same as columns_list.
         num_padded (int, optional): Number of padding samples. Dataset size plus num_padded should be divisible by num_shards.
+        num_samples (int, optional): The number of samples to be included in the dataset
+            (default=None, all samples).
 
     Raises:
         ValueError: If num_shards is specified but shard_id is None.
@@ -2703,7 +2703,7 @@ class MindDataset(SourceDataset):
     def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
                  shuffle=None, num_shards=None, shard_id=None,
                  block_reader=False, sampler=None, padded_sample=None,
-                 num_padded=None):
+                 num_padded=None, num_samples=None):
         super().__init__(num_parallel_workers)
         if isinstance(dataset_file, list):
             self.load_dataset = False
@@ -2712,15 +2712,10 @@ class MindDataset(SourceDataset):
         self.dataset_file = dataset_file
         self.columns_list = columns_list
         self.shuffle_option = shuffle
-        self.distribution = ""
-        self.sampler = sampler
-
-        if num_shards is None or shard_id is None:
-            self.partitions = None
-        else:
-            self.partitions = [num_shards, shard_id]
+        self.num_shards = num_shards
+        self.shard_id = shard_id
 
-        if block_reader is True and self.partitions is not None:
+        if block_reader is True and num_shards is not None:
             raise ValueError("block reader not allowed true when use partitions")
 
         if block_reader is True and shuffle is True:
@@ -2730,25 +2725,21 @@ class MindDataset(SourceDataset):
             logger.warning("WARN: global shuffle is not used.")
 
         if sampler is not None:
-            if isinstance(sampler, samplers.SubsetRandomSampler) is False and \
-                    isinstance(sampler, samplers.PKSampler) is False:
+            if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
+                                    samplers.DistributedSampler, samplers.RandomSampler,
+                                    samplers.SequentialSampler)) is False:
                 raise ValueError("the sampler is not supported yet.")
 
+        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.num_samples = num_samples
+
         # sampler exclusive
         if block_reader is True and sampler is not None:
             raise ValueError("block reader not allowed true when use sampler")
 
-        if shuffle is not None and sampler is not None:
-            raise ValueError("shuffle not allowed when use sampler")
-
-        if block_reader is False and sampler is None:
-            self.shuffle_option = not bool(shuffle is False)
-
         if num_padded is None:
             num_padded = 0
 
-        self.num_shards = num_shards
-        self.shard_id = shard_id
         self.block_reader = block_reader
         self.padded_sample = padded_sample
         self.num_padded = num_padded
@@ -2766,10 +2757,8 @@ class MindDataset(SourceDataset):
         args["load_dataset"] = self.load_dataset
         args["columns_list"] = self.columns_list
         args["shuffle_option"] = self.shuffle_option
-        args["partitions"] = self.partitions
+        args["num_samples"] = self.num_samples
         args["block_reader"] = self.block_reader
-        args["num_shards"] = self.num_shards
-        args["shard_id"] = self.shard_id
         args["num_padded"] = self.num_padded
         args["padded_sample"] = padded_sample
         args["sampler"] = self.sampler
@@ -2788,14 +2777,6 @@ class MindDataset(SourceDataset):
             else:
                 dataset_file = self.dataset_file
             num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
-            if self.partitions is not None and self.partitions[0] > 0:
-                if num_rows % self.partitions[0] == 0:
-                    num_rows = num_rows // self.partitions[0]
-                else:
-                    if self.num_padded > 0:
-                        raise RuntimeError(
-                            "Dataset size plus number of padded samples is not divisible by number of shards.")
-                    num_rows = num_rows // self.partitions[0] + 1
             return num_rows
         return self._dataset_size
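With `MindDataset` now a `MappableDataset`, sharding and sample counts are expressed through the sampler chain instead of the old `partitions` argument. A sketch of the new call surface (`CV_FILE_NAME` is a placeholder for an existing `.mindrecord` file; the exact sampler chosen is delegated to `_select_sampler`):

```python
import mindspore.dataset as ds

CV_FILE_NAME = "/path/to/imagenet.mindrecord0"   # placeholder path

# Cap the dataset at 4 samples; _select_sampler turns this into a sampler.
data1 = ds.MindDataset(CV_FILE_NAME, num_samples=4)

# Similar intent, stated explicitly through a sampler object.
data2 = ds.MindDataset(CV_FILE_NAME, sampler=ds.RandomSampler(num_samples=4))

# Sharding, formerly partitions=[2, 0], now flows through a DistributedSampler.
data3 = ds.MindDataset(CV_FILE_NAME, num_shards=2, shard_id=0)
```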
diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py
index 56ef705f60..0e689bdff5 100644
--- a/mindspore/dataset/engine/samplers.py
+++ b/mindspore/dataset/engine/samplers.py
@@ -141,7 +141,12 @@ class BuiltinSampler:
         c_child_sampler = None
         if self.child_sampler is not None:
             c_child_sampler = self.child_sampler.create()
+        return c_child_sampler
 
+    def create_child_for_minddataset(self):
+        c_child_sampler = None
+        if self.child_sampler is not None:
+            c_child_sampler = self.child_sampler.create_for_minddataset()
         return c_child_sampler
 
     def is_shuffled(self):
@@ -262,6 +267,12 @@ class DistributedSampler(BuiltinSampler):
             c_sampler.add_child(c_child_sampler)
         return c_sampler
 
+    def create_for_minddataset(self):
+        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
     def is_shuffled(self):
         if self.child_sampler is None:
             return self.shuffle
@@ -318,7 +329,7 @@ class PKSampler(BuiltinSampler):
         self.num_val = num_val
         self.shuffle = shuffle
-        self.class_column = class_column # work for minddataset
+        self.class_column = class_column  # work for minddataset
         super().__init__(num_samples)
 
     def create(self):
@@ -340,12 +351,14 @@ class PKSampler(BuiltinSampler):
 
         return self.child_sampler.is_sharded()
 
-    def _create_for_minddataset(self):
+    def create_for_minddataset(self):
         if not self.class_column or not isinstance(self.class_column, str):
             raise ValueError("class_column should be a not empty string value, \
                    but got class_column={}".format(class_column))
-        return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
-
+        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
 
 class RandomSampler(BuiltinSampler):
     """
@@ -390,6 +403,13 @@ class RandomSampler(BuiltinSampler):
             c_sampler.add_child(c_child_sampler)
         return c_sampler
 
+    def create_for_minddataset(self):
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
     def is_shuffled(self):
         return True
 
@@ -440,6 +460,14 @@ class SequentialSampler(BuiltinSampler):
             c_sampler.add_child(c_child_sampler)
         return c_sampler
 
+    def create_for_minddataset(self):
+        start_index = self.start_index if self.start_index is not None else 0
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
     def is_shuffled(self):
         if self.child_sampler is None:
             return False
@@ -501,8 +529,11 @@ class SubsetRandomSampler(BuiltinSampler):
 
         return self.child_sampler.is_sharded()
 
-    def _create_for_minddataset(self):
-        return cde.MindrecordSubsetRandomSampler(self.indices)
+    def create_for_minddataset(self):
+        c_sampler = cde.MindrecordSubsetRandomSampler(self.indices)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
 
     def get_num_samples(self):
         num_samples = super().get_num_samples()
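Every built-in sampler now implements a public `create_for_minddataset`, and `create_child_for_minddataset` converts the chain recursively, so a parent/child pair on the Python side becomes a linked pair of `ShardOperator`s. A hedged example, assuming `BuiltinSampler.add_child` as used by the `create()` paths above:

```python
import mindspore.dataset as ds

CV_FILE_NAME = "/path/to/imagenet.mindrecord0"   # placeholder path

sampler = ds.PKSampler(2)                  # 2 samples per class, root of the chain
sampler.add_child(ds.SequentialSampler())  # child op runs first, PK sampling second
data_set = ds.MindDataset(CV_FILE_NAME, sampler=sampler)
```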
diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py
index 4e6087b9da..8fcefc5889 100644
--- a/tests/ut/python/dataset/test_minddataset_sampler.py
+++ b/tests/ut/python/dataset/test_minddataset_sampler.py
@@ -17,6 +17,7 @@ This is the test module for mindrecord
 """
 import os
 import pytest
+import numpy as np
 
 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -64,10 +65,12 @@ def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 6
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
 
 
@@ -82,12 +85,14 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 6
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[data]: \
 {}------------------------".format(item["data"][:10]))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
 
 
@@ -102,10 +107,12 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 9
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
 
 
@@ -119,10 +126,12 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 15
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
 
 
@@ -219,7 +228,6 @@ def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file)
 
 
 def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
-    """tutorial for cv minderdataset."""
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
     indices = [1, 2, 4, -1, -2]
@@ -241,6 +249,344 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
     assert num_iter == 5
 
 
+def test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler()
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    new_dataset = []
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+        new_dataset.append(item['file_name'])
+    assert num_iter == 10
+    assert new_dataset != [x['file_name'] for x in data]
+
+
+def test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler()
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 10
+    ds1 = data_set.repeat(3)
+    num_iter = 0
+    epoch1_dataset = []
+    epoch2_dataset = []
+    epoch3_dataset = []
+    for item in ds1.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+        if num_iter <= 10:
+            epoch1_dataset.append(item['file_name'])
+        elif num_iter <= 20:
+            epoch2_dataset.append(item['file_name'])
+        else:
+            epoch3_dataset.append(item['file_name'])
+    assert num_iter == 30
+    assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
+    assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
+    assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
+
+
+def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=True, num_samples=5)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 5
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 5
+
+
+def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.SequentialSampler(1, 4)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 4
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[num_iter+1]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 4
+
+
+def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.SequentialSampler(2, 10)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    dataset_size = data_set.get_dataset_size()
+    assert dataset_size == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 10
+
+
+def test_cv_minddataset_split_basic(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    d1, d2 = d.split([8, 2], randomize=False)
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 8
+    assert d2.get_dataset_size() == 2
+    num_iter = 0
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(data[num_iter]['file_name'],
+                                             dtype='S')
+        num_iter += 1
+    assert num_iter == 8
+    num_iter = 0
+    for item in d2.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
+                                             dtype='S')
+        num_iter += 1
+    assert num_iter == 2
+
+
+def test_cv_minddataset_split_exact_percent(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    d1, d2 = d.split([0.8, 0.2], randomize=False)
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 8
+    assert d2.get_dataset_size() == 2
+    num_iter = 0
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[num_iter]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 8
+    num_iter = 0
+    for item in d2.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
+                                             dtype='S')
+        num_iter += 1
+    assert num_iter == 2
+
+
+def test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    d1, d2 = d.split([0.41, 0.59], randomize=False)
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 4
+    assert d2.get_dataset_size() == 6
+    num_iter = 0
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[num_iter]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 4
+    num_iter = 0
+    for item in d2.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(data[num_iter + 4]['file_name'],
+                                             dtype='S')
+        num_iter += 1
+    assert num_iter == 6
+
+
+def test_cv_minddataset_split_deterministic(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    # should set seed to avoid data overlap
+    ds.config.set_seed(111)
+    d1, d2 = d.split([0.8, 0.2])
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 8
+    assert d2.get_dataset_size() == 2
+
+    d1_dataset = []
+    d2_dataset = []
+    num_iter = 0
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        d1_dataset.append(item['file_name'])
+        num_iter += 1
+    assert num_iter == 8
+    num_iter = 0
+    for item in d2.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        d2_dataset.append(item['file_name'])
+        num_iter += 1
+    assert num_iter == 2
+    inter_dataset = [x for x in d1_dataset if x in d2_dataset]
+    assert inter_dataset == []  # intersection of d1 and d2
+
+
+def test_cv_minddataset_split_sharding(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    # should set seed to avoid data overlap
+    ds.config.set_seed(111)
+    d1, d2 = d.split([0.8, 0.2])
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 8
+    assert d2.get_dataset_size() == 2
+    distributed_sampler = ds.DistributedSampler(2, 0)
+    d1.use_sampler(distributed_sampler)
+    assert d1.get_dataset_size() == 4
+
+    num_iter = 0
+    d1_shard1 = []
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+        d1_shard1.append(item['file_name'])
+    assert num_iter == 4
+    assert d1_shard1 != [x['file_name'] for x in data[0:4]]
+
+    distributed_sampler = ds.DistributedSampler(2, 1)
+    d1.use_sampler(distributed_sampler)
+    assert d1.get_dataset_size() == 4
+
+    d1s = d1.repeat(3)
+    epoch1_dataset = []
+    epoch2_dataset = []
+    epoch3_dataset = []
+    num_iter = 0
+    for item in d1s.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+        if num_iter <= 4:
+            epoch1_dataset.append(item['file_name'])
+        elif num_iter <= 8:
+            epoch2_dataset.append(item['file_name'])
+        else:
+            epoch3_dataset.append(item['file_name'])
+    assert len(epoch1_dataset) == 4
+    assert len(epoch2_dataset) == 4
+    assert len(epoch3_dataset) == 4
+    inter_dataset = [x for x in d1_shard1 if x in epoch1_dataset]
+    assert inter_dataset == []  # intersection of d1's shard1 and d1's shard2
+    assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
+    assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
+    assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
+
+
 def get_data(dir_name, sampler=False):
     """
     usage: get data from imagenet dataset