!621 [MD] adjust mindrecord ut

Merge pull request !621 from liyong126/mindrecord_ut
pull/621/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2440cea732

@ -346,7 +346,8 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string
MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg;
return;
}
MS_LOG(INFO) << "Get" << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
std::lock_guard<std::mutex> lck(shard_locker_);
for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
categories.emplace(columns[i][0]);
}

File diff suppressed because it is too large Load Diff

@ -17,6 +17,7 @@
#ifndef TESTS_MINDRECORD_UT_UT_COMMON_H_
#define TESTS_MINDRECORD_UT_UT_COMMON_H_
#include <dirent.h>
#include <fstream>
#include <string>
#include <vector>
@ -25,7 +26,9 @@
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "mindrecord/include/shard_index.h"
#include "mindrecord/include/shard_header.h"
#include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_writer.h"
using json = nlohmann::json;
using std::ifstream;
using std::pair;
@ -40,11 +43,10 @@ class Common : public testing::Test {
std::string install_root;
// every TEST_F macro will enter one
void SetUp();
virtual void SetUp();
void TearDown();
virtual void TearDown();
static void LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num);
};
} // namespace UT
@ -55,6 +57,21 @@ class Common : public testing::Test {
///
/// return the formatted string
const std::string FormatInfo(const std::string &message, uint32_t message_total_length = 128);
void LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num);
void LoadDataFromImageNet(const std::string &directory, std::vector<json> &json_buffer, const int max_num);
int Img2DataUint8(const std::vector<std::string> &img_absolute_path, std::vector<std::vector<uint8_t>> &bin_data);
int GetAbsoluteFiles(std::string directory, std::vector<std::string> &files_absolute_path);
void ShardWriterImageNet();
void ShardWriterImageNetOneSample();
void ShardWriterImageNetOpenForAppend(string filename);
} // namespace mindrecord
} // namespace mindspore
#endif // TESTS_MINDRECORD_UT_UT_COMMON_H_

@ -29,7 +29,6 @@
#include "mindrecord/include/shard_statistics.h"
#include "securec.h"
#include "ut_common.h"
#include "ut_shard_writer_test.h"
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
@ -43,7 +42,7 @@ class TestShard : public UT::Common {
};
TEST_F(TestShard, TestShardSchemaPart) {
TestShardWriterImageNet();
ShardWriterImageNet();
MS_LOG(INFO) << FormatInfo("Test schema");
@ -55,6 +54,12 @@ TEST_F(TestShard, TestShardSchemaPart) {
ASSERT_TRUE(schema != nullptr);
MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " <<
common::SafeCStr(schema->GetSchema().dump());
for (int i = 1; i <= 4; i++) {
string filename = std::string("./imagenet.shard0") + std::to_string(i);
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
remove(common::SafeCStr(filename));
remove(common::SafeCStr(db_name));
}
}
TEST_F(TestShard, TestStatisticPart) {
@ -128,6 +133,5 @@ TEST_F(TestShard, TestShardHeaderPart) {
ASSERT_EQ(resFields, fields);
}
TEST_F(TestShard, TestShardWriteImage) { MS_LOG(INFO) << FormatInfo("Test writer"); }
} // namespace mindrecord
} // namespace mindspore

@ -53,38 +53,6 @@ class TestShardIndexGenerator : public UT::Common {
TestShardIndexGenerator() {}
};
/*
TEST_F(TestShardIndexGenerator, GetField) {
MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field");
int max_num = 1;
string input_path1 = install_root + "/test/testCBGData/data/annotation.data";
std::vector<json> json_buffer1; // store the image_raw_meta.data
Common::LoadData(input_path1, json_buffer1, max_num);
MS_LOG(INFO) << "Fetch fields: ";
for (auto &j : json_buffer1) {
auto v_name = ShardIndexGenerator::GetField("anno_tool", j);
auto v_attr_name = ShardIndexGenerator::GetField("entity_instances.attributes.attr_name", j);
auto v_entity_name = ShardIndexGenerator::GetField("entity_instances.entity_name", j);
vector<string> names = {"\"CVAT\""};
for (unsigned int i = 0; i != names.size(); i++) {
ASSERT_EQ(names[i], v_name[i]);
}
vector<string> attr_names = {"\"脸部评分\"", "\"特征点\"", "\"points_example\"", "\"polyline_example\"",
"\"polyline_example\""};
for (unsigned int i = 0; i != attr_names.size(); i++) {
ASSERT_EQ(attr_names[i], v_attr_name[i]);
}
vector<string> entity_names = {"\"276点人脸\"", "\"points_example\"", "\"polyline_example\"",
"\"polyline_example\""};
for (unsigned int i = 0; i != entity_names.size(); i++) {
ASSERT_EQ(entity_names[i], v_entity_name[i]);
}
}
}
*/
TEST_F(TestShardIndexGenerator, TakeFieldType) {
MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type");

@ -40,6 +40,17 @@ namespace mindrecord {
class TestShardOperator : public UT::Common {
public:
TestShardOperator() {}
void SetUp() override { ShardWriterImageNet(); }
void TearDown() override {
for (int i = 1; i <= 4; i++) {
string filename = std::string("./imagenet.shard0") + std::to_string(i);
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
remove(common::SafeCStr(filename));
remove(common::SafeCStr(db_name));
}
}
};
TEST_F(TestShardOperator, TestShardSampleBasic) {
@ -165,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
auto x = dataset.GetNext();
if (x.empty()) break;
std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
i++;
}
dataset.Finish();
@ -191,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
if (x.empty()) break;
std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
i++;
}
dataset.Finish();

@ -37,6 +37,16 @@ namespace mindrecord {
class TestShardReader : public UT::Common {
public:
TestShardReader() {}
void SetUp() override { ShardWriterImageNet(); }
void TearDown() override {
for (int i = 1; i <= 4; i++) {
string filename = std::string("./imagenet.shard0") + std::to_string(i);
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
remove(common::SafeCStr(filename));
remove(common::SafeCStr(db_name));
}
}
};
TEST_F(TestShardReader, TestShardReaderGeneral) {
@ -51,8 +61,8 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
while (true) {
auto x = dataset.GetNext();
if (x.empty()) break;
for (auto& j : x) {
for (auto& item : std::get<1>(j).items()) {
for (auto &j : x) {
for (auto &item : std::get<1>(j).items()) {
MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
}
}
@ -74,8 +84,8 @@ TEST_F(TestShardReader, TestShardReaderSample) {
while (true) {
auto x = dataset.GetNext();
if (x.empty()) break;
for (auto& j : x) {
for (auto& item : std::get<1>(j).items()) {
for (auto &j : x) {
for (auto &item : std::get<1>(j).items()) {
MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
}
}
@ -99,8 +109,8 @@ TEST_F(TestShardReader, TestShardReaderBlock) {
while (true) {
auto x = dataset.GetBlockNext();
if (x.empty()) break;
for (auto& j : x) {
for (auto& item : std::get<1>(j).items()) {
for (auto &j : x) {
for (auto &item : std::get<1>(j).items()) {
MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
}
}
@ -119,8 +129,8 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
while (true) {
auto x = dataset.GetNext();
if (x.empty()) break;
for (auto& j : x) {
for (auto& item : std::get<1>(j).items()) {
for (auto &j : x) {
for (auto &item : std::get<1>(j).items()) {
MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
}
}
@ -140,8 +150,8 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
while (true) {
auto x = dataset.GetNext();
if (x.empty()) break;
for (auto& j : x) {
for (auto& item : std::get<1>(j).items()) {
for (auto &j : x) {
for (auto &item : std::get<1>(j).items()) {
MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
}
}
@ -169,9 +179,9 @@ TEST_F(TestShardReader, TestShardVersion) {
while (true) {
auto x = dataset.GetNext();
if (x.empty()) break;
for (auto& j : x) {
for (auto &j : x) {
MS_LOG(INFO) << "result size: " << std::get<0>(j).size();
for (auto& item : std::get<1>(j).items()) {
for (auto &item : std::get<1>(j).items()) {
MS_LOG(INFO) << "key: " << common::SafeCStr(item.key()) << ", value: " << common::SafeCStr(item.value().dump());
}
}
@ -201,8 +211,8 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
while (true) {
auto x = dataset.GetNext();
if (x.empty()) break;
for (auto& j : x) {
for (auto& item : std::get<1>(j).items()) {
for (auto &j : x) {
for (auto &item : std::get<1>(j).items()) {
MS_LOG(INFO) << "key: " << common::SafeCStr(item.key()) << ", value: " << common::SafeCStr(item.value().dump());
}
}

@ -33,15 +33,25 @@
#include "mindrecord/include/shard_segment.h"
#include "ut_common.h"
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
namespace mindspore {
namespace mindrecord {
class TestShardSegment : public UT::Common {
public:
TestShardSegment() {}
void SetUp() override { ShardWriterImageNet(); }
void TearDown() override {
for (int i = 1; i <= 4; i++) {
string filename = std::string("./imagenet.shard0") + std::to_string(i);
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
remove(common::SafeCStr(filename));
remove(common::SafeCStr(db_name));
}
}
};
TEST_F(TestShardSegment, TestShardSegment) {

File diff suppressed because it is too large Load Diff

@ -1,26 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TESTS_MINDRECORD_UT_SHARDWRITER_H
#define TESTS_MINDRECORD_UT_SHARDWRITER_H
namespace mindspore {
namespace mindrecord {
void TestShardWriterImageNet();
} // namespace mindrecord
} // namespace mindspore
#endif // TESTS_MINDRECORD_UT_SHARDWRITER_H
Loading…
Cancel
Save