!1457 fix 3 bug reports for split

Merge pull request !1457 from Peilin/splitOp-after-testing
pull/1457/MERGE
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit 4e8e82f24a

@@ -71,7 +71,7 @@ if __name__ == '__main__':
model = Model(network, loss, opt, {'acc': Accuracy()})
print("============== Starting Training ==============")
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)

@@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status RandomSampler::InitSampler() {
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");
rnd_.seed(seed_);
if (replacement_ == false) {
num_samples_ = std::min(num_samples_, num_rows_);
shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) {
shuffled_ids_.push_back(i);
}
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
} else {
num_samples_ = std::min(num_samples_, user_num_samples_);
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK();
}
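The reworked InitSampler replaces the single combined check with separate ones and clamps num_samples_ per branch: without replacement it can never exceed num_rows_, while with replacement it is capped by user_num_samples_. A minimal Python model of the new sizing rule (illustrative only; the names mirror the C++ members):

    def effective_num_samples(user_num_samples, num_samples, num_rows, replacement):
        # take the smaller of the user's request and the current count
        num_samples = min(user_num_samples, num_samples)
        assert num_rows > 0, "num_rows needs to be positive."
        if not replacement:
            # without replacement we cannot emit more rows than exist
            num_samples = min(num_samples, num_rows)
        else:
            num_samples = min(num_samples, user_num_samples)
        assert num_samples > 0, "num_samples needs to be positive."
        return num_samples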

@@ -32,9 +32,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
}
// Handshake and init child first.
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
}
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
}
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
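The handshake now runs the child sampler's handshake first and null-checks the RandomAccessOp up front, so a parent sampler can size itself from its child's output. In Python terms a chained sampler looks roughly like this (a hypothetical sketch; it assumes the add_child API exercised by test_sampler_chain below, and manifest_file is a placeholder path):

    import mindspore.dataset as ds

    # the parent SequentialSampler draws its row ids from the child's shard
    sampler = ds.SequentialSampler()
    sampler.add_child(ds.DistributedSampler(2, 0))  # num_shards=2, shard_id=0
    data = ds.ManifestDataset(manifest_file, sampler=sampler)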

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
Status SubsetSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
num_samples_ = subset_size_;
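Taken together, the checks guarantee the subset lies entirely inside the dataset: the last requested row, start_index + subset_size - 1, must still be a valid index. A small Python model of the validation (illustrative only):

    def validate_subset(start_index, subset_size, num_rows):
        assert subset_size > 0, "subset_size <= 0"
        assert start_index >= 0, "start_index < 0"
        assert start_index < num_rows, "start_index >= num_rows"
        # the final index requested must still be a real row
        assert start_index + subset_size - 1 < num_rows, "Final index out of bounds."

    validate_subset(3, 2, 5)  # ok: rows 3 and 4 exist in a 5-row dataset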

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
WeightedRandomSampler, SubsetSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show
from .engine.graphdata import GraphData

@@ -633,9 +633,9 @@ class Dataset:
Datasets of size f1*K, f2*K, ..., fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.
Note:
1. Dataset cannot be sharded if split is going to be called.
@@ -678,7 +678,8 @@ class Dataset:
ds = copy.deepcopy(self)
if randomize:
# want to shuffle the same way every epoch before split
ds = ds.shuffle()
# in alter_tree, the shuffle buffer is a minimum of 10000, so use 10000 here
ds = ds.shuffle(10000)
ds.reshuffle_each_epoch = False
if rows_to_skip > 0:
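In practice this means a randomized split is shuffled once, with a fixed buffer, and every epoch then sees the same partitioning. A minimal usage sketch of the fractional form (the manifest path is a placeholder):

    import mindspore.dataset as ds

    data = ds.ManifestDataset("/path/to/manifest.json")
    # 80/20 split; the sizes are floats summing to 1, so each piece is
    # rounded to the nearest integer fraction of the dataset
    train, test = data.split([0.8, 0.2], randomize=True)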
@@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset):
>>> new_sampler = ds.DistributedSampler(10, 2)
>>> data.use_sampler(new_sampler)
"""
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise TypeError("new_sampler is not an instance of a sampler.")
self.sampler = self.sampler.child_sampler
self.add_sampler(new_sampler)
@@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset):
def is_sharded(self):
raise NotImplementedError("MappableDataset must implement is_sharded.")
def _get_sampler_dataset_size(self):
if self.sampler is not None:
return self.sampler.get_dataset_size()
return None
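Every mappable dataset's get_dataset_size below now applies the same rule: start from the shard's row count and let an attached sampler narrow it, never enlarge it. Sketched in isolation (a hypothetical free function, for illustration):

    def dataset_size(rows_per_shard, rows_from_sampler):
        # a sampler can only shrink the reported size, never grow it
        if rows_from_sampler is None:
            return rows_per_shard
        return min(rows_from_sampler, rows_per_shard)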
@check_split
def split(self, sizes, randomize=True):
@@ -1236,9 +1245,9 @@
Datasets of size f1*K, f2*K, ..., fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.
Note:
1. Dataset should not be sharded if split is going to be called. Instead, create a
@@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp):
self.iterator = TupleIterator(self)
class RangeDataset(MappableDataset):
"""
A source dataset that reads and parses datasets stored on disk in a range.
@@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset):
else:
num_samples = self.num_samples
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
def num_classes(self):
"""
@@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset):
num_samples = self.num_samples
num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return get_num_rows(num_rows, self.num_shards)
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
if self.shuffle_level is None:
@@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset):
Return:
Number, number of batches.
"""
return self._dataset_size
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self._dataset_size
return min(rows_from_sampler, self._dataset_size)
# manually set dataset_size as a temporary solution.
def set_dataset_size(self, value):
@@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset):
class_indexing = self.class_indexing
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return get_num_rows(num_rows, self.num_shards)
return min(rows_from_sampler, rows_per_shard)
def num_classes(self):
"""
@@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset):
num_samples = self.num_samples
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
if self.shuffle_level is None:
@@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset):
num_samples = self.num_samples
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return get_num_rows(num_rows, self.num_shards)
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
if self.shuffle_level is None:
@@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset):
Return:
Number, number of batches.
"""
return num_samples
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self.num_samples
return min(rows_from_sampler, self.num_samples)
def is_shuffled(self):
return True
@@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset):
Return:
Number, number of batches.
"""
return self.num_samples
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self.num_samples
return min(rows_from_sampler, self.num_samples)
def get_class_indexing(self):
"""

@@ -114,6 +114,9 @@ class Sampler:
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self._get_indices().size
class BuiltinSampler:
"""
@@ -146,6 +149,12 @@ class BuiltinSampler:
def is_sharded(self):
raise NotImplementedError("Sampler must implement is_sharded.")
def get_dataset_size(self):
if self.child_sampler is not None:
return self.child_sampler.get_dataset_size()
return None
class DistributedSampler(BuiltinSampler):
"""
@@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.num_samples
class SequentialSampler(BuiltinSampler):
"""
@@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.subset_size
class SubsetRandomSampler(BuiltinSampler):
"""
@@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler):
return cde.MindrecordSubsetRandomSampler(self.indices)
def get_dataset_size(self):
return len(self.indices)
class WeightedRandomSampler(BuiltinSampler):
"""
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
@@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler):
return False
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.num_samples
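With these additions every sampler can report how many rows it will yield, and BuiltinSampler defers to its child sampler when one is attached. A short usage sketch (manifest_file is a placeholder; SubsetSampler takes start_index then subset_size, as in the tests below):

    import mindspore.dataset as ds

    sampler = ds.SubsetSampler(1, 3)      # rows 1, 2 and 3
    data = ds.ManifestDataset(manifest_file, sampler=sampler)
    assert data.get_dataset_size() == 3   # min(sampler size, rows per shard)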

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
@@ -164,6 +165,35 @@ def test_python_sampler():
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
def test_subset_sampler():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_samples, start_index, subset_size):
sampler = ds.SubsetSampler(start_index, subset_size)
d = ds.ManifestDataset(manifest_file, sampler=sampler)
res = []
for item in d.create_dict_iterator():
res.append(map[(item["image"].shape[0], item["label"].item())])
return res
with pytest.raises(RuntimeError) as info:
test_config(5, 0, 0)
assert "subset_size <= 0" in str(info.value)
assert test_config(5, 0, 1) == [0]
assert test_config(5, 0, 2) == [0, 1]
assert test_config(5, 0, 3) == [0, 1, 2]
assert test_config(5, 0, 4) == [0, 1, 2, 3]
assert test_config(5, 0, 5) == [0, 1, 2, 3, 4]
assert test_config(5, 1, 1) == [1]
assert test_config(5, 2, 3) == [2, 3, 4]
assert test_config(5, 3, 2) == [3, 4]
assert test_config(5, 4, 1) == [4]
def test_sampler_chain():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
@@ -190,10 +220,26 @@ def test_sampler_chain():
assert test_config(5, 3) == [3]
assert test_config(5, 4) == [4]
def test_add_sampler_invalid_input():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
data1 = ds.ManifestDataset(manifest_file)
with pytest.raises(TypeError) as info:
data1.use_sampler(1)
assert "not an instance of a sampler" in str(info.value)
with pytest.raises(TypeError) as info:
data1.use_sampler("sampler")
assert "not an instance of a sampler" in str(info.value)
if __name__ == '__main__':
test_sequential_sampler(True)
test_random_sampler(True)
test_random_sampler_multi_iter(True)
test_sampler_py_api()
test_python_sampler()
test_subset_sampler()
test_sampler_chain()
test_add_sampler_invalid_input()

File diff suppressed because it is too large