Fix num_samples in SequentialSampler and RandomSampler

pull/7690/head
liyong 4 years ago
parent eaa3fe98ed
commit ee042b90f7

@@ -35,33 +35,33 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (per_ > kEpsilon && per_ <= 1.0f) {
     return dataset_size * kEpsilon;
   }
-  return no_of_samples_;
+  return std::min(static_cast<int64_t>(no_of_samples_), dataset_size);
 }
 
 MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
-  int total_no = static_cast<int>(tasks.Size());
-  int taking;
+  int64_t total_no = static_cast<int64_t>(tasks.Size());
+  int64_t taking;
   if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
     taking = total_no;
   } else if (per_ > kEpsilon && per_ <= 1.0f) {
     taking = total_no * kEpsilon;
   } else {
-    taking = no_of_samples_;
+    taking = std::min(static_cast<int64_t>(no_of_samples_), total_no);
   }
 
   if (tasks.permutation_.empty()) {
     ShardTask new_tasks;
-    total_no = static_cast<int>(tasks.Size());
-    for (int i = offset_; i < taking + offset_; ++i) {
+    total_no = static_cast<int64_t>(tasks.Size());
+    for (size_t i = offset_; i < taking + offset_; ++i) {
       new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
     }
     std::swap(tasks, new_tasks);
   } else {  // shuffled
     ShardTask new_tasks;
-    if (taking > static_cast<int>(tasks.permutation_.size())) {
+    if (taking > static_cast<int64_t>(tasks.permutation_.size())) {
       return FAILED;
     }
-    total_no = static_cast<int>(tasks.permutation_.size());
+    total_no = static_cast<int64_t>(tasks.permutation_.size());
     for (size_t i = offset_; i < taking + offset_; ++i) {
       new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
     }
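The two clamps above share one rule: a sequential sampler may return at most dataset_size rows, and the offset wraps modulo the task count. A minimal standalone sketch of that rule (hypothetical helper, not part of the MindSpore API):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Hypothetical helper: emulate the sequential take with clamp and wrap.
    std::vector<int64_t> SequentialTake(const std::vector<int64_t> &ids, int64_t num_samples, int64_t offset) {
      const int64_t total = static_cast<int64_t>(ids.size());
      // The fix: never take more than the dataset actually holds.
      const int64_t taking = (num_samples == 0) ? total : std::min(num_samples, total);
      std::vector<int64_t> out;
      for (int64_t i = offset; i < taking + offset; ++i) {
        out.push_back(ids[i % total]);  // wraps like GetTaskByID(i % total_no)
      }
      return out;
    }

    int main() {
      // 5-element dataset, 10 samples requested: formerly 10 rows with
      // duplicates, now clamped to 5.
      for (int64_t id : SequentialTake({0, 1, 2, 3, 4}, 10, 0)) {
        std::cout << id << ' ';  // prints: 0 1 2 3 4
      }
      std::cout << '\n';
      return 0;
    }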

@@ -39,7 +39,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (replacement_) {
     return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
   }
-  return dataset_size;
+  return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
 }
 
 MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
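Without replacement, each row can be drawn at most once, so the requested count is capped at the dataset size; with replacement the request may legitimately exceed it. A compilable sketch of this resolution rule, mirroring the hunk (assumed standalone function, not the MindSpore API):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // Assumed standalone mirror of ShardShuffle::GetNumSamples.
    int64_t ResolveNumSamples(int64_t dataset_size, int64_t num_samples, bool replacement) {
      if (replacement) {
        // With replacement a row may repeat, so the request stands as-is.
        return num_samples == 0 ? dataset_size : num_samples;
      }
      // Without replacement the count is capped at dataset_size (the fix).
      return num_samples == 0 ? dataset_size : std::min(dataset_size, num_samples);
    }

    int main() {
      assert(ResolveNumSamples(10, 20, true) == 20);   // replacement: may exceed
      assert(ResolveNumSamples(10, 20, false) == 10);  // no replacement: capped
      assert(ResolveNumSamples(10, 0, false) == 10);   // 0 means "all"
      return 0;
    }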
@@ -67,6 +67,14 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
       std::swap(tasks, new_tasks);
     } else {
       std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+      auto total_no = static_cast<int64_t>(tasks.Size());
+      if (no_of_samples_ > 0 && no_of_samples_ < total_no) {
+        ShardTask new_tasks;
+        for (size_t i = 0; i < no_of_samples_; ++i) {
+          new_tasks.InsertTask(tasks.GetTaskByID(i));
+        }
+        std::swap(tasks, new_tasks);
+      }
     }
   } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
     uint32_t individual_size = tasks.Size() / tasks.categories;
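The added block is the standard shuffle-then-truncate pattern: after a full shuffle of the permutation, keeping the first num_samples entries yields a uniform sample without replacement. A self-contained sketch under that reading (names are illustrative):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <random>
    #include <vector>

    // Illustrative names; mirrors the added truncation, not the real classes.
    std::vector<int64_t> SampleWithoutReplacement(int64_t dataset_size, int64_t num_samples, uint32_t seed) {
      std::vector<int64_t> ids(dataset_size);
      std::iota(ids.begin(), ids.end(), int64_t{0});  // 0, 1, ..., dataset_size - 1
      std::shuffle(ids.begin(), ids.end(), std::default_random_engine(seed));
      if (num_samples > 0 && num_samples < dataset_size) {
        ids.resize(num_samples);  // keep only the first num_samples tasks
      }
      return ids;
    }

    int main() {
      auto picked = SampleWithoutReplacement(10, 4, 42);
      return picked.size() == 4 ? 0 : 1;  // 4 distinct ids out of 10
    }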

@@ -208,7 +208,7 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess6) {
   std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
   std::vector<std::string> file_list = {file_path1};
 
-  // Check sequential sampler, output number is 10, with duplicate samples(a little weird, wait to fix)
+  // Check sequential sampler, output number is 5
   std::shared_ptr<Dataset> ds1 = MindData(file_list, {}, SequentialSampler(0, 10));
   EXPECT_NE(ds1, nullptr);
 
@@ -229,7 +229,7 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess6) {
   EXPECT_NE(ds5, nullptr);
 
   std::vector<std::shared_ptr<Dataset>> ds = {ds1, ds2, ds3, ds4, ds5};
-  std::vector<int32_t> expected_samples = {10, 5, 2, 3, 3};
+  std::vector<int32_t> expected_samples = {5, 5, 2, 3, 3};
 
   for (int32_t i = 0; i < ds.size(); i++) {
     // Create an iterator over the result of the above dataset
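The first expectation drops from 10 to 5 because imagenet.mindrecord0 evidently holds 5 records and SequentialSampler(0, 10) is now clamped to the dataset size; before the fix the sampler wrapped around and emitted duplicates. A one-line check of that arithmetic (record count inferred from the test, not verified here):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t dataset_size = 5;   // records in imagenet.mindrecord0 (inferred)
      const int64_t num_samples = 10;   // SequentialSampler(0, 10)
      assert(std::min(num_samples, dataset_size) == 5);  // was 10, with duplicates
      return 0;
    }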

@@ -412,6 +412,46 @@ def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
         num_iter += 1
     assert num_iter == 5
 
 
+def test_cv_minddataset_random_sampler_replacement_false_1(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=False, num_samples=2)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 2
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 2
+
+def test_cv_minddataset_random_sampler_replacement_false_2(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=False, num_samples=20)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 10
+
 def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
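Each new test asserts twice on purpose: get_dataset_size() goes through the planning path (GetNumSamples) while the iterator count goes through the execution path (Execute), and the fix has to keep the two in agreement. A hypothetical harness for that invariant, using the 10-record test file from these tests:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // Planning path: what get_dataset_size() should report (mirrors GetNumSamples).
    int64_t PlannedCount(int64_t dataset_size, int64_t num_samples) {
      return num_samples == 0 ? dataset_size : std::min(dataset_size, num_samples);
    }

    // Execution path: how many tasks survive the added truncation in Execute().
    int64_t ExecutedCount(int64_t dataset_size, int64_t num_samples) {
      return (num_samples > 0 && num_samples < dataset_size) ? num_samples : dataset_size;
    }

    int main() {
      for (int64_t n : {0, 2, 10, 20}) {
        assert(PlannedCount(10, n) == ExecutedCount(10, n));  // 10-record file
      }
      return 0;
    }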
@@ -437,7 +477,7 @@ def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
     assert num_iter == 4
 
 
-def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+def test_cv_minddataset_sequential_sampler_offset(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
@@ -461,6 +501,30 @@ def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
         num_iter += 1
     assert num_iter == 10
 
 
+def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.SequentialSampler(2, 20)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    dataset_size = data_set.get_dataset_size()
+    assert dataset_size == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 10
+
 def test_cv_minddataset_split_basic(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
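The new exceed_size test exercises the clamp and the offset together: SequentialSampler(2, 20) over a 10-record file yields exactly 10 rows, read starting at record 2 and wrapping, which is the (num_iter + 2) % dataset_size index in its assertion. A small sketch of that mapping (illustrative values only):

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t dataset_size = 10;  // records in the test file
      const int64_t offset = 2;         // SequentialSampler(2, 20)
      const int64_t first = (0 + offset) % dataset_size;                // 2
      const int64_t last = (dataset_size - 1 + offset) % dataset_size;  // wraps to 1
      assert(first == 2 && last == 1);
      return 0;
    }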
