|
|
|
@ -30,9 +30,9 @@
|
|
|
|
|
#include "mindrecord/include/shard_shuffle.h"
|
|
|
|
|
#include "ut_common.h"
|
|
|
|
|
|
|
|
|
|
using mindspore::MsLogLevel::INFO;
|
|
|
|
|
using mindspore::ExceptionType::NoExceptionType;
|
|
|
|
|
using mindspore::LogStream;
|
|
|
|
|
using mindspore::ExceptionType::NoExceptionType;
|
|
|
|
|
using mindspore::MsLogLevel::INFO;
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace mindrecord {
|
|
|
|
@ -65,31 +65,31 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
|
|
|
|
|
ASSERT_TRUE(i <= kSampleCount);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
|
|
|
|
|
// MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
|
|
|
|
//
|
|
|
|
|
// std::string file_name = "./imagenet.shard01";
|
|
|
|
|
// auto column_list = std::vector<std::string>{"file_name"};
|
|
|
|
|
//
|
|
|
|
|
// const int kNum = 5;
|
|
|
|
|
// const int kDen = 0;
|
|
|
|
|
// std::vector<std::shared_ptr<ShardOperator>> ops;
|
|
|
|
|
// ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
|
|
|
|
//
|
|
|
|
|
// ShardReader dataset;
|
|
|
|
|
// dataset.Open(file_name, 4, column_list, ops);
|
|
|
|
|
// dataset.Launch();
|
|
|
|
|
//
|
|
|
|
|
// int i = 0;
|
|
|
|
|
// while (true) {
|
|
|
|
|
// auto x = dataset.GetNext();
|
|
|
|
|
// if (x.empty()) break;
|
|
|
|
|
// MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
|
|
|
|
|
// i++;
|
|
|
|
|
// }
|
|
|
|
|
// dataset.Finish();
|
|
|
|
|
// ASSERT_TRUE(i <= 5);
|
|
|
|
|
// }
|
|
|
|
|
TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
|
|
|
|
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
|
|
|
|
|
|
|
|
|
std::string file_name = "./imagenet.shard01";
|
|
|
|
|
auto column_list = std::vector<std::string>{"file_name"};
|
|
|
|
|
|
|
|
|
|
const int kNum = 5;
|
|
|
|
|
const int kDen = 0;
|
|
|
|
|
std::vector<std::shared_ptr<ShardOperator>> ops;
|
|
|
|
|
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
|
|
|
|
|
|
|
|
|
ShardReader dataset;
|
|
|
|
|
dataset.Open(file_name, 4, column_list, ops);
|
|
|
|
|
dataset.Launch();
|
|
|
|
|
|
|
|
|
|
int i = 0;
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
|
|
|
|
|
i++;
|
|
|
|
|
}
|
|
|
|
|
dataset.Finish();
|
|
|
|
|
ASSERT_TRUE(i <= 5);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestShardOperator, TestShardSampleRatio) {
|
|
|
|
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
|
|
|
@ -117,7 +117,6 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
|
|
|
|
|
ASSERT_TRUE(i <= 10);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(TestShardOperator, TestShardSamplePartition) {
|
|
|
|
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
|
|
|
|
std::string file_name = "./imagenet.shard01";
|
|
|
|
@ -170,8 +169,8 @@ TEST_F(TestShardOperator, TestShardCategory) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
|
|
|
|
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
|
|
|
@ -199,8 +198,8 @@ TEST_F(TestShardOperator, TestShardShuffle) {
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
}
|
|
|
|
|
dataset.Finish();
|
|
|
|
@ -224,8 +223,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
}
|
|
|
|
|
dataset.Finish();
|
|
|
|
@ -251,8 +250,8 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
}
|
|
|
|
|
dataset.Finish();
|
|
|
|
@ -278,8 +277,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
}
|
|
|
|
|
dataset.Finish();
|
|
|
|
@ -307,8 +306,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
|
|
|
|
|
auto y = compare_dataset.GetNext();
|
|
|
|
@ -342,8 +341,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
|
|
|
|
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
|
|
|
@ -376,8 +375,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
|
|
|
|
category_no++;
|
|
|
|
@ -410,8 +409,8 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
|
|
|
|
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
|
|
|
@ -448,8 +447,8 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
|
|
|
|
|
while (true) {
|
|
|
|
|
auto x = dataset.GetNext();
|
|
|
|
|
if (x.empty()) break;
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
|
|
|
|
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
|
|
|
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
|
|
|
|
i++;
|
|
|
|
|
|
|
|
|
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
|
|
|
|