fix num samples in pk sampler

pull/4220/head
liyong 5 years ago
parent 4276050f24
commit 7341421d3b

@ -48,12 +48,12 @@ PYBIND_REGISTER(
ShardPkSample, 1, ([](const py::module *m) { ShardPkSample, 1, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>( (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
*m, "MindrecordPkSampler") *m, "MindrecordPkSampler")
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle, int64_t num_samples) {
if (shuffle == true) { if (shuffle == true) {
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(), return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
GetSeed()); GetSeed(), num_samples);
} else { } else {
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal); return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples);
} }
})); }));
})); }));

@ -29,19 +29,23 @@ namespace mindspore {
namespace mindrecord { namespace mindrecord {
class ShardPkSample : public ShardCategory { class ShardPkSample : public ShardCategory {
public: public:
ShardPkSample(const std::string &category_field, int64_t num_elements); ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples);
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, int64_t num_samples);
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed,
int64_t num_samples);
~ShardPkSample() override{}; ~ShardPkSample() override{};
MSRStatus SufExecute(ShardTask &tasks) override; MSRStatus SufExecute(ShardTask &tasks) override;
int64_t GetNumSamples() const { return num_samples_; }
private: private:
bool shuffle_; bool shuffle_;
std::shared_ptr<ShardShuffle> shuffle_op_; std::shared_ptr<ShardShuffle> shuffle_op_;
int64_t num_samples_;
}; };
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

@ -49,6 +49,7 @@
#include "minddata/mindrecord/include/shard_error.h" #include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_index_generator.h" #include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_operator.h" #include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_reader.h" #include "minddata/mindrecord/include/shard_reader.h"
#include "minddata/mindrecord/include/shard_sample.h" #include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h" #include "minddata/mindrecord/include/shard_shuffle.h"

@ -53,7 +53,8 @@ class ShardTask {
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask(); std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements); static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
int64_t num_samples);
uint32_t categories; uint32_t categories;

@ -827,6 +827,12 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
std::string category_field = category_op->GetCategoryField(); std::string category_field = category_op->GetCategoryField();
auto num_classes = GetNumClasses(category_field); auto num_classes = GetNumClasses(category_field);
num_samples = category_op->GetNumSamples(num_samples, num_classes); num_samples = category_op->GetNumSamples(num_samples, num_classes);
if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
if (tmp != 0) {
num_samples = std::min(num_samples, tmp);
}
}
} else if (std::dynamic_pointer_cast<ShardSample>(op)) { } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) { if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op); auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
@ -958,6 +964,14 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
auto categories = category_op->GetCategories(); auto categories = category_op->GetCategories();
int64_t num_elements = category_op->GetNumElements(); int64_t num_elements = category_op->GetNumElements();
int64_t num_samples = 0;
if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
if (num_samples < 0) {
MS_LOG(ERROR) << "Parameter num_samples is not positive or zero";
return FAILED;
}
}
if (num_elements <= 0) { if (num_elements <= 0) {
MS_LOG(ERROR) << "Parameter num_element is not positive"; MS_LOG(ERROR) << "Parameter num_element is not positive";
return FAILED; return FAILED;
@ -1006,7 +1020,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
} }
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
} }
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
if (SUCCESS != (*category_op)(tasks_)) { if (SUCCESS != (*category_op)(tasks_)) {
return FAILED; return FAILED;
} }

@ -22,15 +22,18 @@ using mindspore::MsLogLevel::ERROR;
namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples)
: ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true), shuffle_(false) {} : ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true),
shuffle_(false),
num_samples_(num_samples) {}
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} int64_t num_samples)
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false), num_samples_(num_samples) {}
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
uint32_t seed) uint32_t seed, int64_t num_samples)
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true), num_samples_(num_samples) {
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement
} }

@ -86,7 +86,8 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
return task_list_[dis(gen)]; return task_list_[dis(gen)];
} }
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) { ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
int64_t num_samples) {
ShardTask res; ShardTask res;
if (category_tasks.empty()) return res; if (category_tasks.empty()) return res;
auto total_categories = category_tasks.size(); auto total_categories = category_tasks.size();
@ -96,9 +97,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
for (uint32_t i = 1; i < total_categories; i++) { for (uint32_t i = 1; i < total_categories; i++) {
minTasks = std::min(minTasks, category_tasks[i].Size()); minTasks = std::min(minTasks, category_tasks[i].Size());
} }
int64_t count = 0;
for (uint32_t task_no = 0; task_no < minTasks; task_no++) { for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
for (uint32_t i = 0; i < total_categories; i++) { for (uint32_t i = 0; i < total_categories; i++) {
if (num_samples != 0 && count == num_samples) break;
res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no)))); res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
count++;
} }
} }
} else { } else {
@ -109,9 +113,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
if (num_elements != std::numeric_limits<int64_t>::max()) { if (num_elements != std::numeric_limits<int64_t>::max()) {
maxTasks = static_cast<decltype(maxTasks)>(num_elements); maxTasks = static_cast<decltype(maxTasks)>(num_elements);
} }
int64_t count = 0;
for (uint32_t i = 0; i < total_categories; i++) { for (uint32_t i = 0; i < total_categories; i++) {
for (uint32_t j = 0; j < maxTasks; j++) { for (uint32_t j = 0; j < maxTasks; j++) {
if (num_samples != 0 && count == num_samples) break;
res.InsertTask(category_tasks[i].GetRandomTask()); res.InsertTask(category_tasks[i].GetRandomTask());
count++;
} }
} }
} }

@ -359,7 +359,8 @@ class PKSampler(BuiltinSampler):
if not self.class_column or not isinstance(self.class_column, str): if not self.class_column or not isinstance(self.class_column, str):
raise ValueError("class_column should be a not empty string value, \ raise ValueError("class_column should be a not empty string value, \
but got class_column={}".format(class_column)) but got class_column={}".format(class_column))
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
c_child_sampler = self.create_child_for_minddataset() c_child_sampler = self.create_child_for_minddataset()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler

@ -104,7 +104,7 @@ class TFRecordToMR:
Args: Args:
source (str): the TFRecord file to be transformed. source (str): the TFRecord file to be transformed.
destination (str): the MindRecord file path to tranform into. destination (str): the MindRecord file path to tranform into.
feature_dict (dict): a dictionary than states the feature type, i.e. feature_dict (dict): a dictionary that states the feature type, i.e.
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \ feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
"yyyy": tf.io.FixedLenFeature([], tf.int64)} "yyyy": tf.io.FixedLenFeature([], tf.int64)}

@ -162,7 +162,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
auto column_list = std::vector<std::string>{"file_name", "label"}; auto column_list = std::vector<std::string>{"file_name", "label"};
std::vector<std::shared_ptr<ShardOperator>> ops; std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardPkSample>("label", 2)); ops.push_back(std::make_shared<ShardPkSample>("label", 2, 0));
ShardReader dataset; ShardReader dataset;
dataset.Open({file_name},true, 4, column_list, ops); dataset.Open({file_name},true, 4, column_list, ops);
@ -187,7 +187,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
auto column_list = std::vector<std::string>{"file_name", "label"}; auto column_list = std::vector<std::string>{"file_name", "label"};
std::vector<std::shared_ptr<ShardOperator>> ops; std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0)); ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0, 0));
ShardReader dataset; ShardReader dataset;
dataset.Open({file_name},true, 4, column_list, ops); dataset.Open({file_name},true, 4, column_list, ops);
@ -204,7 +204,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
} }
dataset.Finish(); dataset.Finish();
ASSERT_TRUE(i == 6); ASSERT_TRUE(i == 6);
} // namespace mindrecord }
TEST_F(TestShardOperator, TestShardCategory) { TEST_F(TestShardOperator, TestShardCategory) {
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));

@ -101,7 +101,6 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
"-------------- item[label]: {} ----------------------------".format(item["label"])) "-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1 num_iter += 1
def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
"""tutorial for cv minderdataset.""" """tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
@ -120,9 +119,51 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
logger.info( logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"])) "-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1 num_iter += 1
assert num_iter == 9
def test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file):
    """Read a MindDataset through PKSampler(3, None, True, 'label', 5) and expect 5 rows.

    Asserts that get_dataset_size() reports 5 and that iterating the dataset
    with create_dict_iterator() yields exactly 5 items.

    Args:
        add_and_remove_cv_file: presumably a pytest fixture that prepares and
            cleans up the CV MindRecord files; it is not used directly in the
            body and is defined outside this view.
    """
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    # NOTE(review): the last positional argument (5) looks like num_samples and the
    # first (3) like the per-class count; confirm against the PKSampler signature.
    sampler = ds.PKSampler(3, None, True, 'label', 5)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    # The size reported up front must match the number of rows iterated below.
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        # The continued line stays flush-left: inside a backslash-continued string
        # literal, any leading whitespace would become part of the logged message.
        logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 5
def test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file):
    """Read a MindDataset through PKSampler(3, None, True, 'label', 10) and expect 9 rows.

    The last sampler argument (10) is larger than the expected row count, and
    the test asserts that both get_dataset_size() and the iterated item count
    are 9 rather than 10.

    Args:
        add_and_remove_cv_file: presumably a pytest fixture that prepares and
            cleans up the CV MindRecord files; it is not used directly in the
            body and is defined outside this view.
    """
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    # NOTE(review): the last positional argument (10) looks like num_samples acting
    # as an upper bound only; confirm against the PKSampler signature.
    sampler = ds.PKSampler(3, None, True, 'label', 10)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    # The size reported up front must match the number of rows iterated below.
    assert data_set.get_dataset_size() == 9
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        # The continued line stays flush-left: inside a backslash-continued string
        # literal, any leading whitespace would become part of the logged message.
        logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 9
def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file):
"""tutorial for cv minderdataset.""" """tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
num_readers = 4 num_readers = 4
@ -139,6 +180,45 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
logger.info( logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"])) "-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1 num_iter += 1
assert num_iter == 15
def test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file):
    """Read a MindDataset through PKSampler(5, None, True, 'label', 20) and expect 15 rows.

    The last sampler argument (20) is larger than the expected row count, and
    the test asserts that both get_dataset_size() and the iterated item count
    are 15 rather than 20.

    Args:
        add_and_remove_cv_file: presumably a pytest fixture that prepares and
            cleans up the CV MindRecord files; it is not used directly in the
            body and is defined outside this view.
    """
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    # NOTE(review): the first argument (5) looks like a per-class count and the last
    # (20) like num_samples; confirm against the PKSampler signature.
    sampler = ds.PKSampler(5, None, True, 'label', 20)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    # The size reported up front must match the number of rows iterated below.
    assert data_set.get_dataset_size() == 15
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        # The continued line stays flush-left: inside a backslash-continued string
        # literal, any leading whitespace would become part of the logged message.
        logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 15
def test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file):
    """Read a MindDataset through PKSampler(5, None, True, 'label', 10) and expect 10 rows.

    Asserts that get_dataset_size() reports 10 and that iterating the dataset
    with create_dict_iterator() yields exactly 10 items.

    Args:
        add_and_remove_cv_file: presumably a pytest fixture that prepares and
            cleans up the CV MindRecord files; it is not used directly in the
            body and is defined outside this view.
    """
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    # NOTE(review): the first argument (5) looks like a per-class count and the last
    # (10) like num_samples limiting the output; confirm against the PKSampler signature.
    sampler = ds.PKSampler(5, None, True, 'label', 10)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    # The size reported up front must match the number of rows iterated below.
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        # The continued line stays flush-left: inside a backslash-continued string
        # literal, any leading whitespace would become part of the logged message.
        logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10
def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):

Loading…
Cancel
Save