Support list of IDs as a sampler

pull/11854/head
hesham 4 years ago
parent b1a44e8875
commit 1185218335
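
With this change, a plain Python list (or any other iterable) of sample IDs can be passed directly as the `sampler` argument of a dataset. A minimal usage sketch (the dataset path mirrors the test data used further below; the snippet itself is not part of the commit):

    import mindspore.dataset as ds

    # Read only samples 1, 3 and 5 of the image folder, in that order.
    data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=[1, 3, 5])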

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,6 @@ import os
import numpy as np
import mindspore._c_dataengine as cde
from ..engine import samplers
# POS_INT_MIN is used to keep values from starting at 0
POS_INT_MIN = 1
@@ -290,8 +289,6 @@ def check_sampler_shuffle_shard_options(param_dict):
    num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
    num_samples = param_dict.get('num_samples')
    type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")
    if sampler is not None:
        if shuffle is not None:
            raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
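
For reference, a sketch of what this guard rejects (dataset and arguments are illustrative, not taken from the commit):

    sampler = ds.RandomSampler()
    # Raises RuntimeError: sampler and shuffle cannot be specified at the same time.
    ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=sampler, shuffle=False)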

File diff suppressed because it is too large.

@@ -25,6 +25,82 @@ import mindspore._c_dataengine as cde
import mindspore.dataset as ds
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
    Create a sampler based on user input.

    Args:
        num_samples (int): Number of samples.
        input_sampler (Union[Iterable, Sampler]): Sampler from user.
        shuffle (bool): Whether to shuffle.
        num_shards (int): Number of shards for sharding.
        shard_id (int): Shard ID.

    Returns:
        Sampler, sampler selected based on user input.
    """

    def _is_iterable(obj):
        # Anything iter() accepts (list, tuple, generator, user sampler) counts as iterable.
        try:
            iter(obj)
        except TypeError:
            return False
        return True

    def _get_sample_ids_as_list(sampler, number_of_samples=None):
        # Materialize the iterable of sample IDs, truncated to number_of_samples if given.
        if number_of_samples is None:
            return list(sampler)
        if isinstance(sampler, list):
            return sampler[:number_of_samples]
        return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]

    if input_sampler is not None:
        # If the user provided a sampler, then it doesn't matter what the other args are because
        # we are being asked specifically to use the given sampler.
        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
        # be None. Consider this example:
        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
        # In this case, the user has given sampler-related arguments that contradict each other.
        # To prevent this, only allow the user to manually specify the sampler if those arguments
        # are all None.
        if (isinstance(input_sampler, BuiltinSampler) and
                any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples])):
            raise ValueError(
                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
                ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
        if isinstance(input_sampler, BuiltinSampler):
            return input_sampler
        if _is_iterable(input_sampler):
            return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
        if isinstance(input_sampler, int):
            # A bare sample ID behaves like a one-element list of IDs; wrap it so the
            # function always returns a sampler object rather than a raw list.
            return SubsetSampler([input_sampler])
        raise ValueError('Unsupported sampler object ({})'.format(input_sampler))

    if shuffle is None:
        if num_shards is not None:
            # If shuffle is not specified, sharding enabled, use distributed random sampler
            shuffle = True
            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        # If shuffle is not specified, sharding disabled, use random sampler
        if num_samples is not None:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)

    if shuffle is True:
        if num_shards is not None:
            # If shuffle enabled, sharding enabled, use distributed random sampler
            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        # If shuffle enabled, sharding disabled, use random sampler
        if num_samples is not None:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)

    if num_shards is not None:
        # If shuffle disabled, sharding enabled, use distributed sequential sampler
        return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
    # If shuffle disabled, sharding disabled, use sequential sampler
    return SequentialSampler(num_samples=num_samples)
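
# For reference (added for illustration; not part of the original commit), the
# selection logic above resolves, for example, as:
#   select_sampler(None, None,      None,  None, None) -> RandomSampler()
#   select_sampler(None, None,      False, None, None) -> SequentialSampler()
#   select_sampler(None, None,      None,  8,    3)    -> DistributedSampler(8, 3, shuffle=True)
#   select_sampler(None, [1, 3, 5], None,  None, None) -> SubsetSampler([1, 3, 5])
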
class BuiltinSampler:
    """
    Base class for BuiltinSampler.

@@ -17,6 +17,7 @@ import pytest
import mindspore.dataset as ds
from mindspore import log as logger
from util import dataset_equal
# test5trainimgs.json contains 5 images whose un-decoded sizes are [83554, 54214, 65512, 54214, 64631]
@@ -265,6 +266,15 @@ def test_distributed_sampler_invalid_offset():
assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)
def test_sampler_list():
data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=[1, 3, 5])
data21 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(2).skip(1)
data22 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(4).skip(3)
data23 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(6).skip(5)
dataset_equal(data1, data21 + data22 + data23, 0)
if __name__ == '__main__':
    test_sequential_sampler(True)
    test_random_sampler(True)
@@ -276,3 +286,4 @@ if __name__ == '__main__':
    test_sampler_chain()
    test_add_sampler_invalid_input()
    test_distributed_sampler_invalid_offset()
    test_sampler_list()
