new dataset api -- Coco

pull/3860/head
xiefangqi 5 years ago
parent ccb80694fd
commit bb7bda1d6b

@ -22,6 +22,7 @@
#include "minddata/dataset/engine/dataset_iterator.h"
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
@ -106,6 +107,15 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::s
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a CocoDataset.
std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task, bool decode, std::shared_ptr<SamplerObj> sampler) {
auto ds = std::make_shared<CocoDataset>(dataset_dir, annotation_file, task, decode, sampler);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a ImageFolderDataset.
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
@ -384,6 +394,97 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
return node_ops;
}
// Constructor for CocoDataset
CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
bool decode, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}
bool CocoDataset::ValidateParams() {
Path dir(dataset_dir_);
if (!dir.IsDirectory()) {
MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
return false;
}
Path annotation_file(annotation_file_);
if (!annotation_file.Exists()) {
MS_LOG(ERROR) << "annotation_file is invalid or not exist";
return false;
}
std::set<std::string> task_list = {"Detection", "Stuff", "Panoptic", "Keypoint"};
auto task_iter = task_list.find(task_);
if (task_iter == task_list.end()) {
MS_LOG(ERROR) << "Invalid task type";
return false;
}
return true;
}
// Function to build CocoDataset
std::vector<std::shared_ptr<DatasetOp>> CocoDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
if (sampler_ == nullptr) {
sampler_ = CreateDefaultSampler();
}
CocoOp::TaskType task_type;
if (task_ == "Detection") {
task_type = CocoOp::TaskType::Detection;
} else if (task_ == "Stuff") {
task_type = CocoOp::TaskType::Stuff;
} else if (task_ == "Keypoint") {
task_type = CocoOp::TaskType::Keypoint;
} else if (task_ == "Panoptic") {
task_type = CocoOp::TaskType::Panoptic;
}
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor(std::string("image"), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
switch (task_type) {
case CocoOp::TaskType::Detection:
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
break;
case CocoOp::TaskType::Stuff:
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("segmentation"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
break;
case CocoOp::TaskType::Keypoint:
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("keypoints"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("num_keypoints"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
break;
case CocoOp::TaskType::Panoptic:
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
break;
default:
MS_LOG(ERROR) << "CocoDataset::Build : Invalid task type";
return {};
}
std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
node_ops.push_back(op);
return node_ops;
}
ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
bool recursive, std::set<std::string> extensions,
std::map<std::string, int32_t> class_indexing)

@ -43,6 +43,7 @@ class SamplerObj;
// Datasets classes (in alphabetical order)
class Cifar10Dataset;
class Cifar100Dataset;
class CocoDataset;
class ImageFolderDataset;
class MnistDataset;
class VOCDataset;
@ -75,6 +76,26 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
std::shared_ptr<SamplerObj> sampler = nullptr);
/// \brief Function to create a CocoDataset
/// \notes The generated dataset has multi-columns :
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
/// ['iscrowd', dtype=uint32]].
/// - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd', dtype=uint32]].
/// - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
/// ['num_keypoints', dtype=uint32]].
/// - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
/// ['iscrowd', dtype=uint32], ['area', dtype=uitn32]].
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] annotation_file Path to the annotation json
/// \param[in] task Set the task type of reading coco data, now support 'Detection'/'Stuff'/'Panoptic'/'Keypoint'
/// \param[in] decode Decode the images after reading
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task = "Detection", bool decode = false,
std::shared_ptr<SamplerObj> sampler = nullptr);
/// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories
/// All images within one folder have the same label
@ -298,6 +319,31 @@ class Cifar100Dataset : public Dataset {
std::shared_ptr<SamplerObj> sampler_;
};
class CocoDataset : public Dataset {
public:
/// \brief Constructor
CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, bool decode,
std::shared_ptr<SamplerObj> sampler);
/// \brief Destructor
~CocoDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::string dataset_dir_;
std::string annotation_file_;
std::string task_;
bool decode_;
std::shared_ptr<SamplerObj> sampler_;
};
/// \class ImageFolderDataset
/// \brief A Dataset derived class to represent ImageFolder dataset
class ImageFolderDataset : public Dataset {

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save