Support list of IDs as a sampler

pull/11854/head
hesham 5 years ago
parent b1a44e8875
commit 1185218335

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -21,7 +21,6 @@ import os
import numpy as np import numpy as np
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
from ..engine import samplers
# POS_INT_MIN is used to limit values from starting from 0 # POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1 POS_INT_MIN = 1
@ -290,8 +289,6 @@ def check_sampler_shuffle_shard_options(param_dict):
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
num_samples = param_dict.get('num_samples') num_samples = param_dict.get('num_samples')
type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")
if sampler is not None: if sampler is not None:
if shuffle is not None: if shuffle is not None:
raise RuntimeError("sampler and shuffle cannot be specified at the same time.") raise RuntimeError("sampler and shuffle cannot be specified at the same time.")

File diff suppressed because it is too large Load Diff

@ -25,6 +25,82 @@ import mindspore._c_dataengine as cde
import mindspore.dataset as ds import mindspore.dataset as ds
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
    Create a sampler based on the user's input arguments.

    Args:
        num_samples (int): Number of samples to draw, or None for all samples.
        input_sampler (Union[Iterable, Sampler]): Sampler from user, a plain
            iterable/list of sample IDs, a single int ID, or None.
        shuffle (bool): Whether to shuffle, or None when unspecified.
        num_shards (int): Number of shards for sharding, or None when sharding
            is disabled.
        shard_id (int): Shard ID of the current shard within num_shards.

    Returns:
        BuiltinSampler, sampler selected based on user input.

    Raises:
        ValueError: If a BuiltinSampler is given together with conflicting
            sampler-related arguments, or input_sampler is an unsupported type.
    """
    def _is_iterable(obj):
        # EAFP probe: anything iter() accepts is treated as an iterable of IDs.
        try:
            iter(obj)
        except TypeError:
            return False
        return True

    def _get_sample_ids_as_list(sampler, number_of_samples=None):
        # Materialize at most number_of_samples IDs from the given iterable.
        if number_of_samples is None:
            return list(sampler)
        if isinstance(sampler, list):
            return sampler[:number_of_samples]
        # Generic iterable (possibly unsized): zip against a bounded range to
        # stop early without requiring len() or slicing support.
        return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]

    if input_sampler is not None:
        # If the user provided a sampler, then it doesn't matter what the other args are because
        # we are being asked specifically to use the given sampler.
        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
        # be None. Consider this example:
        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
        # In this case, the user has given different sample-related arguments that contradict each other.
        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
        if (isinstance(input_sampler, BuiltinSampler) and
                (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
            raise ValueError(
                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
                ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
        if isinstance(input_sampler, BuiltinSampler):
            return input_sampler
        if _is_iterable(input_sampler):
            return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
        if isinstance(input_sampler, int):
            # Bug fix: wrap the single ID in a SubsetSampler so this branch
            # returns a sampler object like every other branch, not a bare list.
            return SubsetSampler([input_sampler])
        raise ValueError('Unsupported sampler object ({})'.format(input_sampler))
    if shuffle is None:
        if num_shards is not None:
            # If shuffle is not specified, sharding enabled, use distributed random sampler
            shuffle = True
            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        # If shuffle is not specified, sharding disabled, use random sampler
        if num_samples is not None:
            # Bounded draw: sample with replacement up to num_samples.
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)
    if shuffle is True:
        if num_shards is not None:
            # If shuffle enabled, sharding enabled, use distributed random sampler
            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        # If shuffle enabled, sharding disabled, use random sampler
        if num_samples is not None:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)
    if num_shards is not None:
        # If shuffle disabled, sharding enabled, use distributed sequential sampler
        return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
    # If shuffle disabled, sharding disabled, use sequential sampler
    return SequentialSampler(num_samples=num_samples)
class BuiltinSampler: class BuiltinSampler:
""" """
Base class for BuiltinSampler. Base class for BuiltinSampler.

@ -17,6 +17,7 @@ import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
from util import dataset_equal
# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631] # test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
@ -265,6 +266,15 @@ def test_distributed_sampler_invalid_offset():
assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value) assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)
def test_sampler_list():
    """Verify that passing a plain list of IDs as a sampler selects exactly those samples."""
    sampled = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=[1, 3, 5])
    # Build the expected result by slicing out each selected index from a
    # non-shuffled pipeline: take(i + 1).skip(i) isolates the sample at index i.
    expected_parts = [
        ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(index + 1).skip(index)
        for index in (1, 3, 5)
    ]
    dataset_equal(sampled, expected_parts[0] + expected_parts[1] + expected_parts[2], 0)
if __name__ == '__main__': if __name__ == '__main__':
test_sequential_sampler(True) test_sequential_sampler(True)
test_random_sampler(True) test_random_sampler(True)
@ -276,3 +286,4 @@ if __name__ == '__main__':
test_sampler_chain() test_sampler_chain()
test_add_sampler_invalid_input() test_add_sampler_invalid_input()
test_distributed_sampler_invalid_offset() test_distributed_sampler_invalid_offset()
test_sampler_list()

Loading…
Cancel
Save