|
|
@ -58,7 +58,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
TensorRow row;
|
|
|
|
TensorRow row;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
db->PopRow(&row);
|
|
|
|
db->PopRow(&row);
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
@ -69,7 +69,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(num_samples, out.size());
|
|
|
|
ASSERT_EQ(num_samples, out.size());
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -88,7 +88,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
TensorRow row;
|
|
|
|
TensorRow row;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
db->PopRow(&row);
|
|
|
|
db->PopRow(&row);
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
@ -105,7 +105,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -124,7 +124,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
TensorRow row;
|
|
|
|
TensorRow row;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
int epoch = 0;
|
|
|
|
int epoch = 0;
|
|
|
|
while (!db->eoe()) {
|
|
|
|
while (!db->eoe()) {
|
|
|
|
epoch++;
|
|
|
|
epoch++;
|
|
|
@ -135,7 +135,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
db.reset();
|
|
|
|
db.reset();
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer);
|
|
|
|
ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer);
|
|
|
@ -160,7 +160,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
TensorRow row;
|
|
|
|
TensorRow row;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
int epoch = 0;
|
|
|
|
int epoch = 0;
|
|
|
|
while (!db->eoe()) {
|
|
|
|
while (!db->eoe()) {
|
|
|
|
epoch++;
|
|
|
|
epoch++;
|
|
|
@ -172,7 +172,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
db.reset();
|
|
|
|
db.reset();
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Without replacement, each sample only drawn once.
|
|
|
|
// Without replacement, each sample only drawn once.
|
|
|
@ -201,7 +201,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
TensorRow row;
|
|
|
|
TensorRow row;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
db->PopRow(&row);
|
|
|
|
db->PopRow(&row);
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
@ -211,13 +211,13 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ASSERT_EQ(num_samples, out.size());
|
|
|
|
ASSERT_EQ(num_samples, out.size());
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
|
|
|
|
|
|
|
|
m_sampler.Reset();
|
|
|
|
m_sampler.Reset();
|
|
|
|
out.clear();
|
|
|
|
out.clear();
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
db->PopRow(&row);
|
|
|
|
db->PopRow(&row);
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
@ -227,7 +227,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ASSERT_EQ(num_samples, out.size());
|
|
|
|
ASSERT_EQ(num_samples, out.size());
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -246,7 +246,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
std::unique_ptr<DataBuffer> db;
|
|
|
|
TensorRow row;
|
|
|
|
TensorRow row;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
std::vector<uint64_t> out;
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
db->PopRow(&row);
|
|
|
|
db->PopRow(&row);
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
@ -256,7 +256,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ASSERT_EQ(num_samples, out.size());
|
|
|
|
ASSERT_EQ(num_samples, out.size());
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
|
|
|
|
|
|
|
|
m_sampler.Reset();
|
|
|
|
m_sampler.Reset();
|
|
|
@ -265,7 +265,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|
|
|
freq.resize(total_samples, 0);
|
|
|
|
freq.resize(total_samples, 0);
|
|
|
|
MS_LOG(INFO) << "Resetting sampler";
|
|
|
|
MS_LOG(INFO) << "Resetting sampler";
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
db->PopRow(&row);
|
|
|
|
db->PopRow(&row);
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (const auto &t : row) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
|
|
@ -282,6 +282,6 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
|
|
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
ASSERT_EQ(db->eoe(), true);
|
|
|
|
}
|
|
|
|
}
|
|
|
|