minddata support voc

pull/1126/head
xiefangqi 5 years ago
parent 25b2424f9b
commit c937bad53f

@ -0,0 +1,10 @@
set(tinyxml2_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 -Wno-unused-result")
set(tinyxml2_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(tinyxml2
VER 8.0.0
LIBS tinyxml2
URL https://github.com/leethomason/tinyxml2/archive/8.0.0.tar.gz
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release
MD5 5dc535c8b34ee621fe2128f072d275b5)
include_directories(${tinyxml2_INC})
add_library(mindspore::tinyxml2 ALIAS tinyxml2::tinyxml2)

@ -56,6 +56,7 @@ if (ENABLE_MINDDATA)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/libtiff.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/opencv.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sqlite.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/tinyxml2.cmake)
endif()
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)

@ -39,6 +39,7 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
set(opencv_LIBPATH ${opencv_LIBPATH}/../bin/)
set(jpeg_turbo_LIBPATH ${jpeg_turbo_LIBPATH}/../bin/)
set(sqlite_LIBPATH ${sqlite_LIBPATH}/../bin/)
set(tinyxml2_LIBPATH ${tinyxml2_LIBPATH}/../bin/)
else ()
set(INSTALL_LIB_DIR "lib")
endif ()
@ -82,6 +83,15 @@ if (ENABLE_MINDDATA)
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
file(GLOB_RECURSE TINYXML2_LIB_LIST
${tinyxml2_LIBPATH}/libtinyxml2*
)
install(
FILES ${TINYXML2_LIB_LIST}
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif ()
if (ENABLE_CPU)

@ -90,7 +90,7 @@ else()
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
endif()
target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs
mindspore::opencv_imgproc)
mindspore::opencv_imgproc mindspore::tinyxml2)
if (ENABLE_GPUQUE)
target_link_libraries(_c_dataengine PRIVATE gpu_queue
${CUDNN_PATH}/lib64/libcudnn.so

@ -898,6 +898,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
(void)builder->SetDir(ToString(args["dataset_dir"]));
(void)builder->SetTask(ToString(args["task"]));
(void)builder->SetMode(ToString(args["mode"]));
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
@ -912,6 +914,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
(void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
} else if (key == "class_indexing") {
(void)builder->SetClassIndex(ToStringMap(value));
}
}
}

@ -55,6 +55,7 @@
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
@ -193,6 +194,13 @@ void bindDatasetOps(py::module *m) {
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
return count;
});
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
return output_class_indexing;
});
}
void bindTensor(py::module *m) {
(void)py::class_<GlobalContext>(*m, "GlobalContext")

File diff suppressed because it is too large Load Diff

@ -16,6 +16,7 @@
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
@ -39,8 +40,12 @@ namespace dataset {
template <typename T>
class Queue;
using Bbox = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
class VOCOp : public ParallelOp, public RandomAccessOp {
public:
enum class TaskType { Segmentation = 0, Detection = 1 };
class Builder {
public:
// Constructor for Builder class of ImageFolderOp
@ -59,6 +64,34 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method.
// @param const std::map<std::string, int32_t> &map - a class name to label map
// @return Builder setter method returns reference to the builder.
Builder &SetClassIndex(const std::map<std::string, int32_t> &map) {
builder_labels_to_read_ = map;
return *this;
}
// Setter method.
// @param const std::string & task_type
// @return Builder setter method returns reference to the builder.
Builder &SetTask(const std::string &task_type) {
if (task_type == "Segmentation") {
builder_task_type_ = TaskType::Segmentation;
} else if (task_type == "Detection") {
builder_task_type_ = TaskType::Detection;
}
return *this;
}
// Setter method.
// @param const std::string & task_mode
// @return Builder setter method returns reference to the builder.
Builder &SetMode(const std::string &task_mode) {
builder_task_mode_ = task_mode;
return *this;
}
// Setter method.
// @param int32_t num_workers
// @return Builder setter method returns reference to the builder.
@ -119,25 +152,33 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
private:
bool builder_decode_;
std::string builder_dir_;
TaskType builder_task_type_;
std::string builder_task_mode_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int32_t builder_rows_per_buffer_;
int64_t builder_num_samples_;
std::shared_ptr<Sampler> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
std::map<std::string, int32_t> builder_labels_to_read_;
};
// Constructor
// @param TaskType task_type - task type of VOC
// @param std::string task_mode - task mode of VOC
// @param std::string folder_path - dir directory of VOC
// @param std::map<std::string, int32_t> class_index - input class-to-index of annotation
// @param int32_t num_workers - number of workers reading images in parallel
// @param int32_t rows_per_buffer - number of images (rows) in each buffer
// @param std::string folder_path - dir directory of VOC
// @param int32_t queue_size - connector queue size
// @param int64_t num_samples - number of samples to read
// @param bool decode - whether to decode images
// @param std::unique_ptr<DataSchema> data_schema - the schema of the VOC dataset
// @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read
VOCOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &folder_path, int32_t queue_size,
int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler);
// Destructor
~VOCOp() = default;
@ -167,6 +208,16 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
// @param const py::dict &dict - input dict of class index
// @param int64_t numSamples - samples number of VOCDataset
// @param std::map<std::string, int32_t> *output_class_indexing - output class index of VOCDataset
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples,
std::map<std::string, int32_t> *output_class_indexing);
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
@ -184,19 +235,40 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
// @param const std::string &path - path to the image file
// @param const ColDescriptor &col - contains tensor implementation and datatype
// @param std::shared_ptr<Tensor> tensor - return
// @return Status - The error code return
Status ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
// @param const std::vector<uint64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
// Read image list from ImageSets
// @return Status - The error code return
Status ParseImageIds();
// Read annotation from Annotation folder
// @return Status - The error code return
Status ParseAnnotationIds();
// @param const std::string &path - path to annotation xml
// @return Status - The error code return
Status ParseAnnotationBbox(const std::string &path);
// @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
// @param std::vector<int64_t> *keys - image id
// @return Status - The error code return
Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
// Called first when function is called
// @return Status - The error code return
Status LaunchThreadsAndInitOp();
// Reset dataset state
// @return Status - The error code return
Status Reset() override;
bool decode_;
@ -205,6 +277,8 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
int64_t num_rows_;
int64_t num_samples_;
std::string folder_path_;
TaskType task_type_;
std::string task_mode_;
int32_t rows_per_buffer_;
std::shared_ptr<Sampler> sampler_;
std::unique_ptr<DataSchema> data_schema_;
@ -212,6 +286,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
WaitPost wp_;
std::vector<std::string> image_ids_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
std::map<std::string, int32_t> class_index_;
std::map<std::string, int32_t> label_index_;
std::map<std::string, Bbox> label_map_;
};
} // namespace dataset
} // namespace mindspore

@ -34,7 +34,7 @@ import copy
import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
MindRecordOp, TextFileOp, CBatchInfo
MindRecordOp, TextFileOp, VOCOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
@ -3454,6 +3454,12 @@ class VOCDataset(SourceDataset):
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
task (str): Set the task type of reading voc data, now only support "Segmentation" or "Detection"
(default="Segmentation")
mode(str): Set the data list txt file to be readed (default="train")
class_indexing (dict, optional): A str-to-int mapping from label name to index
(default=None, the folder names will be sorted alphabetically and each
class will be given a unique index starting from 0).
num_samples (int, optional): The number of images to be included in the dataset
(default=None, all images).
num_parallel_workers (int, optional): Number of workers to read the data
@ -3469,27 +3475,41 @@ class VOCDataset(SourceDataset):
argument should be specified only when num_shards is also specified.
Raises:
RuntimeError: If xml of Annotations is a invalid format
RuntimeError: If xml of Annotations loss attribution of "object"
RuntimeError: If xml of Annotations loss attribution of "bndbox"
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If task is not equal 'Segmentation' or 'Detection'.
ValueError: If task equal 'Segmentation' but class_indexing is not None.
ValueError: If txt related to mode is not exist.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Examples:
>>> import mindspore.dataset as ds
>>> dataset_dir = "/path/to/voc_dataset_directory"
>>> # 1) read all VOC dataset samples in dataset_dir with 8 threads in random order:
>>> voc_dataset = ds.VOCDataset(dataset_dir, num_parallel_workers=8)
>>> # 2) read then decode all VOC dataset samples in dataset_dir in sequence:
>>> voc_dataset = ds.VOCDataset(dataset_dir, decode=True, shuffle=False)
>>> # in VOC dataset, each dictionary has keys "image" and "target"
>>> # 1) read VOC data for segmenatation train
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", mode="train")
>>> # 2) read VOC data for detection train
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train")
>>> # 3) read all VOC dataset samples in dataset_dir with 8 threads in random order:
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", num_parallel_workers=8)
>>> # 4) read then decode all VOC dataset samples in dataset_dir in sequence:
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", decode=True, shuffle=False)
>>> # in VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target"
>>> # in VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation"
"""
@check_vocdataset
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
def __init__(self, dataset_dir, task="Segmentation", mode="train", class_indexing=None, num_samples=None,
num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.task = task
self.mode = mode
self.class_indexing = class_indexing
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
self.decode = decode
@ -3500,6 +3520,9 @@ class VOCDataset(SourceDataset):
def get_args(self):
args = super().get_args()
args["dataset_dir"] = self.dataset_dir
args["task"] = self.task
args["mode"] = self.mode
args["class_indexing"] = self.class_indexing
args["num_samples"] = self.num_samples
args["sampler"] = self.sampler
args["decode"] = self.decode
@ -3517,6 +3540,28 @@ class VOCDataset(SourceDataset):
"""
return self.num_samples
def get_class_indexing(self):
"""
Get the class index.
Return:
Dict, A str-to-int mapping from label name to index.
"""
if self.task != "Detection":
raise NotImplementedError()
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
class CelebADataset(SourceDataset):
"""

@ -285,9 +285,9 @@ def create_node(node):
elif dataset_op == 'VOCDataset':
sampler = construct_sampler(node.get('sampler'))
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
node.get('shuffle'), node.get('decode'), sampler, node.get('num_shards'),
node.get('shard_id'))
pyobj = pyclass(node['dataset_dir'], node.get('task'), node.get('mode'), node.get('class_indexing'),
node.get('num_samples'), node.get('num_parallel_workers'), node.get('shuffle'),
node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id'))
elif dataset_op == 'CelebADataset':
sampler = construct_sampler(node.get('sampler'))

@ -455,17 +455,44 @@ def check_vocdataset(method):
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
nreq_param_dict = ['class_indexing']
# check dataset_dir; required argument
dataset_dir = param_dict.get('dataset_dir')
if dataset_dir is None:
raise ValueError("dataset_dir is not provided.")
check_dataset_dir(dataset_dir)
# check task; required argument
task = param_dict.get('task')
if task is None:
raise ValueError("task is not provided.")
if not isinstance(task, str):
raise ValueError("task is not str type.")
# check mode; required argument
mode = param_dict.get('mode')
if mode is None:
raise ValueError("mode is not provided.")
if not isinstance(mode, str):
raise ValueError("mode is not str type.")
imagesets_file = ""
if task == "Segmentation":
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", mode + ".txt")
if param_dict.get('class_indexing') is not None:
raise ValueError("class_indexing is invalid in Segmentation task")
elif task == "Detection":
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", mode + ".txt")
else:
raise ValueError("Invalid task : " + task)
check_dataset_file(imagesets_file)
check_param_type(nreq_param_int, param_dict, int)
check_param_type(nreq_param_bool, param_dict, bool)
check_param_type(nreq_param_dict, param_dict, dict)
check_sampler_shuffle_shard_options(param_dict)
return method(*args, **kwargs)

@ -64,7 +64,7 @@ SET(DE_UT_SRCS
cifar_op_test.cc
celeba_op_test.cc
take_op_test.cc
text_file_op_test.cc)
text_file_op_test.cc
filter_op_test.cc
)

@ -50,17 +50,170 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
std::shared_ptr<VOCOp> CreateVOC(int64_t num_wrks, int64_t rows, int64_t conns, std::string path,
bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
int64_t num_samples = 0, bool decode = false) {
std::shared_ptr<VOCOp> so;
class MindDataTestVOCOp : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestVOCOp, TestVOCDetection) {
// Start with an empty execution tree
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testVOC2012";
std::string task_type("Detection");
std::string task_mode("train");
std::shared_ptr<VOCOp> my_voc_op;
VOCOp::Builder builder;
Status rc = builder.SetNumWorkers(num_wrks).SetDir(path).SetRowsPerBuffer(rows)
.SetOpConnectorSize(conns).SetSampler(std::move(sampler))
.SetNumSamples(num_samples).SetDecode(decode).Build(&so);
return so;
Status rc = builder.SetDir(dataset_path)
.SetTask(task_type)
.SetMode(task_mode)
.Build(&my_voc_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_voc_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_voc_op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "Launch tree and begin iteration.";
rc = my_tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = my_tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
MS_LOG(DEBUG) << "Row display for row #: " << row_count << ".";
//Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(DEBUG) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 9);
}
class MindDataTestVOCSampler : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestVOCOp, TestVOCSegmentation) {
// Start with an empty execution tree
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testVOC2012";
std::string task_type("Segmentation");
std::string task_mode("train");
std::shared_ptr<VOCOp> my_voc_op;
VOCOp::Builder builder;
Status rc = builder.SetDir(dataset_path)
.SetTask(task_type)
.SetMode(task_mode)
.Build(&my_voc_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_voc_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_voc_op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "Launch tree and begin iteration.";
rc = my_tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = my_tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
MS_LOG(DEBUG) << "Row display for row #: " << row_count << ".";
//Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(DEBUG) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 10);
}
TEST_F(MindDataTestVOCOp, TestVOCClassIndex) {
// Start with an empty execution tree
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testVOC2012";
std::string task_type("Detection");
std::string task_mode("train");
std::map<std::string, int32_t> class_index;
class_index["car"] = 0;
class_index["cat"] = 1;
class_index["train"] = 5;
std::shared_ptr<VOCOp> my_voc_op;
VOCOp::Builder builder;
Status rc = builder.SetDir(dataset_path)
.SetTask(task_type)
.SetMode(task_mode)
.SetClassIndex(class_index)
.Build(&my_voc_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_voc_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_voc_op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "Launch tree and begin iteration.";
rc = my_tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = my_tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
MS_LOG(DEBUG) << "Row display for row #: " << row_count << ".";
//Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(DEBUG) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 6);
}

@ -0,0 +1,51 @@
<annotation>
<folder>VOC2012</folder>
<filename>32.jpg</filename>
<source>
<database>simulate VOC2007 Database</database>
<annotation>simulate VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>500</width>
<height>281</height>
<depth>3</depth>
</size>
<segmented>1</segmented>
<object>
<name>train</name>
<pose>Frontal</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>113</xmin>
<ymin>79</ymin>
<xmax>323</xmax>
<ymax>191</ymax>
</bndbox>
</object>
<object>
<name>train</name>
<pose>Left</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>121</xmin>
<ymin>91</ymin>
<xmax>191</xmax>
<ymax>121</ymax>
</bndbox>
</object>
<object>
<name>car</name>
<pose>Rear</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>195</xmin>
<ymin>155</ymin>
<xmax>235</xmax>
<ymax>235</ymax>
</bndbox>
</object>
</annotation>

@ -1,54 +0,0 @@
<annotation>
<folder>VOC2012</folder>
<filename>27.jpg</filename>
<source>
<database>simulate VOC2007 Database</database>
<annotation>simulate VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>486</width>
<height>500</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>person</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>161</xmin>
<ymin>132</ymin>
<xmax>323</xmax>
<ymax>342</ymax>
</bndbox>
<part>
<name>head</name>
<bndbox>
<xmin>159</xmin>
<ymin>113</ymin>
<xmax>208</xmax>
<ymax>166</ymax>
</bndbox>
</part>
<part>
<name>foot</name>
<bndbox>
<xmin>261</xmin>
<ymin>321</ymin>
<xmax>287</xmax>
<ymax>344</ymax>
</bndbox>
</part>
<part>
<name>foot</name>
<bndbox>
<xmin>329</xmin>
<ymin>317</ymin>
<xmax>330</xmax>
<ymax>366</ymax>
</bndbox>
</part>
</object>
</annotation>

@ -0,0 +1,15 @@
<annotation>
<folder>VOC2012</folder>
<filename>33.jpg</filename>
<source>
<database>simulate VOC2007 Database</database>
<annotation>simulate VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>500</width>
<height>366</height>
<depth>3</depth>
</size>
<segmented>1</segmented>
</annotation>

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save