!1913 CocoDataset implementation

Merge pull request !1913 from xiefangqi/xfq_support_coco
pull/1913/MERGE
Authored by mindspore-ci-bot 5 years ago; committed by Gitee
commit b3da41bd7a

@@ -23,6 +23,7 @@
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
@@ -65,6 +66,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
    {kMnist, &DEPipeline::ParseMnistOp},
    {kManifest, &DEPipeline::ParseManifestOp},
    {kVoc, &DEPipeline::ParseVOCOp},
    {kCoco, &DEPipeline::ParseCocoOp},
    {kCifar10, &DEPipeline::ParseCifar10Op},
    {kCifar100, &DEPipeline::ParseCifar100Op},
    {kCelebA, &DEPipeline::ParseCelebAOp},
@@ -920,6 +922,16 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  if (args["task"].is_none()) {
    std::string err_msg = "Error: No task specified";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  if (args["mode"].is_none()) {
    std::string err_msg = "Error: No mode specified";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
  (void)builder->SetDir(ToString(args["dataset_dir"]));
  (void)builder->SetTask(ToString(args["task"]));
@@ -947,6 +959,47 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
  return Status::OK();
}
Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  if (args["dataset_dir"].is_none()) {
    std::string err_msg = "Error: No dataset path specified";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  if (args["annotation_file"].is_none()) {
    std::string err_msg = "Error: No annotation_file specified";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  if (args["task"].is_none()) {
    std::string err_msg = "Error: No task specified";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  std::shared_ptr<CocoOp::Builder> builder = std::make_shared<CocoOp::Builder>();
  (void)builder->SetDir(ToString(args["dataset_dir"]));
  (void)builder->SetFile(ToString(args["annotation_file"]));
  (void)builder->SetTask(ToString(args["task"]));
  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
        (void)builder->SetSampler(std::move(sampler));
      } else if (key == "decode") {
        (void)builder->SetDecode(ToBool(value));
      }
    }
  }
  std::shared_ptr<CocoOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}

Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  // Required arguments
  if (args["dataset_dir"].is_none()) {

@@ -58,6 +58,7 @@ enum OpName {
  kMnist,
  kManifest,
  kVoc,
  kCoco,
  kCifar10,
  kCifar100,
  kCelebA,
@@ -142,6 +143,8 @@ class DEPipeline {
  Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

@@ -56,6 +56,7 @@
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/text/kernels/jieba_tokenizer_op.h"
@@ -214,6 +215,18 @@ void bindDatasetOps(py::module *m) {
        THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
        return output_class_indexing;
      });
  (void)py::class_<CocoOp, DatasetOp, std::shared_ptr<CocoOp>>(*m, "CocoOp")
    .def_static("get_class_indexing",
                [](const std::string &dir, const std::string &file, const std::string &task) {
                  std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
                  THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing));
                  return output_class_indexing;
                })
    .def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) {
      int64_t count = 0;
      THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count));
      return count;
    });
}

void bindTensor(py::module *m) {
  (void)py::class_<GlobalContext>(*m, "GlobalContext")
@@ -607,6 +620,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
    .value("MNIST", OpName::kMnist)
    .value("MANIFEST", OpName::kManifest)
    .value("VOC", OpName::kVoc)
    .value("COCO", OpName::kCoco)
    .value("CIFAR10", OpName::kCifar10)
    .value("CIFAR100", OpName::kCifar100)
    .value("RANDOMDATA", OpName::kRandomData)

@@ -13,6 +13,7 @@ add_library(engine-datasetops-source OBJECT
    image_folder_op.cc
    mnist_op.cc
    voc_op.cc
    coco_op.cc
    manifest_op.cc
    cifar_op.cc
    random_data_op.cc

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -20,8 +20,8 @@ can also create samplers with this module to sample data.
from .core.configuration import config
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
    GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset, \
    TextFileDataset, Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
    WeightedRandomSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show
@@ -30,5 +30,5 @@ from .engine.graphdata import GraphData
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
           "MindDataset", "GeneratorDataset", "TFRecordDataset",
           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
           "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
           "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]

@@ -33,5 +33,5 @@ __all__ = ["config", "ConfigurationManager", "zip",
           "ImageFolderDatasetV2", "MnistDataset",
           "MindDataset", "GeneratorDataset", "TFRecordDataset",
           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
           "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler",
           "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]

@@ -33,7 +33,7 @@ import copy
import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
    MindRecordOp, TextFileOp, VOCOp, CocoOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
@@ -42,8 +42,9 @@ from .iterators import DictIterator, TupleIterator
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
    check_rename, \
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
    check_split
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

try:
@@ -3867,10 +3868,14 @@ class VOCDataset(MappableDataset):
    """
    A source dataset for reading and parsing VOC dataset.

    The generated dataset has two columns:
    task='Detection': ['image', 'annotation'];
    task='Segmentation': ['image', 'target'].
    The shape of columns 'image' and 'target' is [image_size] if the decode flag is False, or [H, W, C]
    otherwise.
    The type of tensors 'image' and 'target' is uint8.
    The type of tensor 'annotation' is uint32.

    This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
    below shows what input args are allowed and their expected behavior.
@@ -4035,6 +4040,163 @@ class VOCDataset(MappableDataset):
        return self.sampler.is_sharded()

class CocoDataset(MappableDataset):
    """
    A source dataset for reading and parsing COCO dataset.

    CocoDataset supports four kinds of task on the COCO 2017 Train/Val/Test data:
    Detection, Stuff, Panoptic and Keypoint.

    The generated dataset has multiple columns:
    task='Detection': [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
    ['iscrowd', dtype=uint32]];
    task='Stuff': [['image', dtype=uint8], ['segmentation', dtype=float32], ['iscrowd', dtype=uint32]];
    task='Keypoint': [['image', dtype=uint8], ['keypoints', dtype=float32], ['num_keypoints', dtype=uint32]];
    task='Panoptic': [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
    ['iscrowd', dtype=uint32], ['area', dtype=uint32]].

    This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
    below shows what input args are allowed and their expected behavior.
    .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter 'sampler'
         - Parameter 'shuffle'
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed
    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        annotation_file (str): Path to the annotation json.
        task (str): Set the task type for reading COCO data; one of 'Detection', 'Stuff', 'Panoptic'
            or 'Keypoint' (default='Detection').
        num_samples (int, optional): The number of images to be included in the dataset
            (default=None, all images).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, number set in the config).
        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
            order behavior shown in the table).
        decode (bool, optional): Decode the images after reading (default=False).
        sampler (Sampler, optional): Object used to choose samples from the dataset
            (default=None, expected order behavior shown in the table).
        num_shards (int, optional): Number of shards that the dataset should be divided
            into (default=None).
        shard_id (int, optional): The shard ID within num_shards (default=None). This
            argument should be specified only when num_shards is also specified.

    Raises:
        RuntimeError: If sampler and shuffle are specified at the same time.
        RuntimeError: If sampler and sharding are specified at the same time.
        RuntimeError: If num_shards is specified but shard_id is None.
        RuntimeError: If shard_id is specified but num_shards is None.
        RuntimeError: If parsing the json file fails.
        ValueError: If task is not in ['Detection', 'Stuff', 'Panoptic', 'Keypoint'].
        ValueError: If annotation_file does not exist.
        ValueError: If dataset_dir does not exist.
        ValueError: If shard_id is invalid (< 0 or >= num_shards).
    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset_dir = "/path/to/coco_dataset_directory/image_folder"
        >>> annotation_file = "/path/to/coco_dataset_directory/annotation_folder/annotation.json"
        >>> # 1) read COCO data for Detection task
        >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Detection')
        >>> # 2) read COCO data for Stuff task
        >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Stuff')
        >>> # 3) read COCO data for Panoptic task
        >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Panoptic')
        >>> # 4) read COCO data for Keypoint task
        >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Keypoint')
        >>> # in COCO dataset, each dictionary has keys "image" and "annotation"
    """
    @check_cocodataset
    def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None,
                 shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
        super().__init__(num_parallel_workers)

        self.dataset_dir = dataset_dir
        self.annotation_file = annotation_file
        self.task = task
        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
        self.num_samples = num_samples
        self.decode = decode
        self.shuffle_level = shuffle
        self.num_shards = num_shards
        self.shard_id = shard_id

    def get_args(self):
        args = super().get_args()
        args["dataset_dir"] = self.dataset_dir
        args["annotation_file"] = self.annotation_file
        args["task"] = self.task
        args["num_samples"] = self.num_samples
        args["sampler"] = self.sampler
        args["decode"] = self.decode
        args["shuffle"] = self.shuffle_level
        args["num_shards"] = self.num_shards
        args["shard_id"] = self.shard_id
        return args
    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Returns:
            Number, number of batches.
        """
        num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task)
        rows_per_shard = get_num_rows(num_rows, self.num_shards)
        rows_from_sampler = self._get_sampler_dataset_size()

        if rows_from_sampler is None:
            return rows_per_shard

        return min(rows_from_sampler, rows_per_shard)

    def get_class_indexing(self):
        """
        Get the class index.

        Returns:
            Dict, a str-to-int mapping from label name to index.
        """
        if self.task not in {"Detection", "Panoptic"}:
            raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.")

        class_index = CocoOp.get_class_indexing(self.dataset_dir, self.annotation_file, self.task)
        return dict(class_index)

    def is_shuffled(self):
        if self.shuffle_level is None:
            return True

        return self.shuffle_level or self.sampler.is_shuffled()

    def is_sharded(self):
        if self.num_shards is not None:
            return self.num_shards > 1

        return self.sampler.is_sharded()
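
Putting the new class together, a minimal usage sketch (paths hypothetical; create_dict_iterator is the existing MindSpore iterator entry point):

    import mindspore.dataset as ds

    coco = ds.CocoDataset("/path/to/coco/images",
                          annotation_file="/path/to/annotation.json",
                          task="Detection", decode=True, shuffle=False)
    print("rows:", coco.get_dataset_size())
    print("classes:", coco.get_class_indexing())  # Detection/Panoptic only

    for row in coco.create_dict_iterator():
        # For task='Detection' each row carries 'image', 'bbox', 'category_id', 'iscrowd'.
        print(row["image"].shape, row["bbox"])
        break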

class CelebADataset(MappableDataset):
    """
    A source dataset for reading and parsing CelebA dataset. Only support list_attr_celeba.txt currently.

@@ -165,6 +165,8 @@ class Iterator:
            op_type = OpName.MANIFEST
        elif isinstance(dataset, de.VOCDataset):
            op_type = OpName.VOC
        elif isinstance(dataset, de.CocoDataset):
            op_type = OpName.COCO
        elif isinstance(dataset, de.Cifar10Dataset):
            op_type = OpName.CIFAR10
        elif isinstance(dataset, de.Cifar100Dataset):

@@ -299,6 +299,12 @@ def create_node(node):
                        node.get('num_samples'), node.get('num_parallel_workers'), node.get('shuffle'),
                        node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'CocoDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('annotation_file'), node.get('task'), node.get('num_samples'),
                        node.get('num_parallel_workers'), node.get('shuffle'), node.get('decode'), sampler,
                        node.get('num_shards'), node.get('shard_id'))
    elif dataset_op == 'CelebADataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'),
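
create_node above is the deserialization half; a hedged sketch of the round trip it enables, using the serialize/deserialize entry points re-exported earlier in this diff (file name hypothetical):

    import mindspore.dataset as ds

    coco = ds.CocoDataset("/path/to/coco/images",
                          annotation_file="/path/to/annotation.json", task="Stuff")
    ds.serialize(coco, "coco_pipeline.json")   # writes the node dict create_node reads back
    restored = ds.deserialize(json_filepath="coco_pipeline.json")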

@@ -524,6 +524,49 @@ def check_vocdataset(method):
    return new_method

def check_cocodataset(method):
    """A wrapper that wraps a parameter checker to the original Dataset(CocoDataset)."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        if dataset_dir is None:
            raise ValueError("dataset_dir is not provided.")
        check_dataset_dir(dataset_dir)

        # check annotation_file; required argument
        annotation_file = param_dict.get('annotation_file')
        if annotation_file is None:
            raise ValueError("annotation_file is not provided.")
        check_dataset_file(annotation_file)

        # check task; required argument
        task = param_dict.get('task')
        if task is None:
            raise ValueError("task is not provided.")
        if not isinstance(task, str):
            raise ValueError("task is not str type.")
        if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
            raise ValueError("Invalid task type")

        check_param_type(nreq_param_int, param_dict, int)
        check_param_type(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(*args, **kwargs)

    return new_method

def check_celebadataset(method):
    """A wrapper that wraps a parameter checker to the original Dataset(CelebADataset)."""

@@ -71,6 +71,7 @@ SET(DE_UT_SRCS
    jieba_tokenizer_op_test.cc
    tokenizer_op_test.cc
    gnn_graph_test.cc
    coco_op_test.cc
)

add_executable(de_ut_tests ${DE_UT_SRCS})

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

@@ -0,0 +1 @@
{"info": {"description": "COCO 2017 Dataset", "url": "http://cocodataset.org", "version": "1.0", "year": 2017, "contributor": "COCO Consortium", "data_created": "2017/09/01"}, "images":[{"license": 3, "file_name": "000000391895.jpg", "id": 391895},{"license": 3, "file_name": "000000318219.jpg", "id": 318219}],"annotations": [{"segmentation": [[10.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,19.0,20.0]], "num_keypoints": 10,"area": 12345,"iscrowd": 0,"keypoints": [244,139,2,0,0,0,226,118,2,0,0,0,154,159,2,143,261,2,135,312,2,271,423,2,184,530,2,261,280,2,347,592,2,0,0,0,123,596,2,0,0,0,0,0,0,0,0,0,0,0,0],"image_id": 318219,"bbox": [40.65,38.8,418.38,601.2],"category_id": 1, "id": 491276},{"segmentation": [[20.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0]], "num_keypoints": 14,"area": 45678,"iscrowd": 0,"keypoints": [368,61,1,369,52,2,0,0,0,382,48,2,0,0,0,368,84,2,435,81,2,362,125,2,446,125,2,360,153,2,0,0,0,397,167,1,439,166,1,369,193,2,461,234,2,361,246,2,474,287,2],"image_id": 391895,"bbox": [339.88,22.16,153.88,300.73],"category_id": 1, "id": 202758}]}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Six binary image files added (not shown): 172 KiB, 207 KiB, 51 KiB, 170 KiB, 63 KiB, 53 KiB.

Some files were not shown because too many files have changed in this diff.
