|
|
@ -82,10 +82,15 @@ namespace dataset {
|
|
|
|
|
|
|
|
|
|
|
|
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
|
|
|
|
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
|
|
|
|
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
|
|
|
|
.def("SetNumWorkers",
|
|
|
|
.def("set_num_workers",
|
|
|
|
[](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) {
|
|
|
|
[](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) {
|
|
|
|
return num_workers ? self->SetNumWorkers(*num_workers) : self;
|
|
|
|
return num_workers ? self->SetNumWorkers(*num_workers) : self;
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
.def("set_cache_client",
|
|
|
|
|
|
|
|
[](std::shared_ptr<DatasetNode> self) {
|
|
|
|
|
|
|
|
std::shared_ptr<DatasetCache> dc = nullptr;
|
|
|
|
|
|
|
|
return self->SetDatasetCache(dc);
|
|
|
|
|
|
|
|
})
|
|
|
|
.def(
|
|
|
|
.def(
|
|
|
|
"Zip",
|
|
|
|
"Zip",
|
|
|
|
[](std::shared_ptr<DatasetNode> self, py::list datasets) {
|
|
|
|
[](std::shared_ptr<DatasetNode> self, py::list datasets) {
|
|
|
@ -109,10 +114,9 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
|
|
|
|
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
|
|
|
|
"to create a CelebANode")
|
|
|
|
"to create a CelebANode")
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode,
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode,
|
|
|
|
py::list extensions, std::shared_ptr<CacheClient> cc) {
|
|
|
|
py::list extensions) {
|
|
|
|
auto celebA =
|
|
|
|
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
|
|
|
|
std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
|
|
|
|
toStringSet(extensions), nullptr);
|
|
|
|
toStringSet(extensions), toDatasetCache(std::move(cc)));
|
|
|
|
|
|
|
|
THROW_IF_ERROR(celebA->ValidateParams());
|
|
|
|
THROW_IF_ERROR(celebA->ValidateParams());
|
|
|
|
return celebA;
|
|
|
|
return celebA;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -121,10 +125,8 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
|
|
|
|
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
|
|
|
|
"to create a Cifar10Node")
|
|
|
|
"to create a Cifar10Node")
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler,
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
|
|
|
std::shared_ptr<CacheClient> cc) {
|
|
|
|
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
|
|
|
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler),
|
|
|
|
|
|
|
|
toDatasetCache(std::move(cc)));
|
|
|
|
|
|
|
|
THROW_IF_ERROR(cifar10->ValidateParams());
|
|
|
|
THROW_IF_ERROR(cifar10->ValidateParams());
|
|
|
|
return cifar10;
|
|
|
|
return cifar10;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -133,23 +135,22 @@ PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
|
|
|
|
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
|
|
|
|
"to create a Cifar100Node")
|
|
|
|
"to create a Cifar100Node")
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler,
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
|
|
|
std::shared_ptr<CacheClient> cc) {
|
|
|
|
auto cifar100 =
|
|
|
|
auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler),
|
|
|
|
std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
|
|
|
toDatasetCache(std::move(cc)));
|
|
|
|
|
|
|
|
THROW_IF_ERROR(cifar100->ValidateParams());
|
|
|
|
THROW_IF_ERROR(cifar100->ValidateParams());
|
|
|
|
return cifar100;
|
|
|
|
return cifar100;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
|
}));
|
|
|
|
}));
|
|
|
|
|
|
|
|
|
|
|
|
PYBIND_REGISTER(
|
|
|
|
PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) {
|
|
|
|
CLUENode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode",
|
|
|
|
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode")
|
|
|
|
"to create a CLUENode")
|
|
|
|
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle,
|
|
|
|
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples,
|
|
|
|
int32_t num_shards, int32_t shard_id, std::shared_ptr<CacheClient> cc) {
|
|
|
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
|
|
|
std::shared_ptr<CLUENode> clue_node =
|
|
|
|
std::shared_ptr<CLUENode> clue_node =
|
|
|
|
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle),
|
|
|
|
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples,
|
|
|
|
num_shards, shard_id, toDatasetCache(std::move(cc)));
|
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
|
|
|
THROW_IF_ERROR(clue_node->ValidateParams());
|
|
|
|
THROW_IF_ERROR(clue_node->ValidateParams());
|
|
|
|
return clue_node;
|
|
|
|
return clue_node;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -159,10 +160,9 @@ PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
|
|
|
|
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
|
|
|
|
"to create a CocoNode")
|
|
|
|
"to create a CocoNode")
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
|
|
|
|
bool decode, py::handle sampler, std::shared_ptr<CacheClient> cc) {
|
|
|
|
bool decode, py::handle sampler) {
|
|
|
|
std::shared_ptr<CocoNode> coco =
|
|
|
|
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
|
|
|
|
std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, toSamplerObj(sampler),
|
|
|
|
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr);
|
|
|
|
toDatasetCache(std::move(cc)));
|
|
|
|
|
|
|
|
THROW_IF_ERROR(coco->ValidateParams());
|
|
|
|
THROW_IF_ERROR(coco->ValidateParams());
|
|
|
|
return coco;
|
|
|
|
return coco;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -172,10 +172,10 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
|
|
|
|
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
|
|
|
|
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
|
|
|
|
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
|
|
|
|
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
|
|
|
|
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
|
|
|
|
int32_t num_shards, int32_t shard_id, std::shared_ptr<CacheClient> cc) {
|
|
|
|
int32_t num_shards, int32_t shard_id) {
|
|
|
|
auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults),
|
|
|
|
auto csv =
|
|
|
|
column_names, num_samples, toShuffleMode(shuffle),
|
|
|
|
std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), column_names,
|
|
|
|
num_shards, shard_id, toDatasetCache(std::move(cc)));
|
|
|
|
num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
|
|
|
THROW_IF_ERROR(csv->ValidateParams());
|
|
|
|
THROW_IF_ERROR(csv->ValidateParams());
|
|
|
|
return csv;
|
|
|
|
return csv;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -205,12 +205,12 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
|
|
|
|
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
|
|
|
|
*m, "ImageFolderNode", "to create an ImageFolderNode")
|
|
|
|
*m, "ImageFolderNode", "to create an ImageFolderNode")
|
|
|
|
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions,
|
|
|
|
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions,
|
|
|
|
py::dict class_indexing, std::shared_ptr<CacheClient> cc) {
|
|
|
|
py::dict class_indexing) {
|
|
|
|
// Don't update recursive to true
|
|
|
|
// Don't update recursive to true
|
|
|
|
bool recursive = false; // Will be removed in future PR
|
|
|
|
bool recursive = false; // Will be removed in future PR
|
|
|
|
auto imagefolder = std::make_shared<ImageFolderNode>(
|
|
|
|
auto imagefolder = std::make_shared<ImageFolderNode>(dataset_dir, decode, toSamplerObj(sampler),
|
|
|
|
dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions),
|
|
|
|
recursive, toStringSet(extensions),
|
|
|
|
toStringMap(class_indexing), toDatasetCache(std::move(cc)));
|
|
|
|
toStringMap(class_indexing), nullptr);
|
|
|
|
THROW_IF_ERROR(imagefolder->ValidateParams());
|
|
|
|
THROW_IF_ERROR(imagefolder->ValidateParams());
|
|
|
|
return imagefolder;
|
|
|
|
return imagefolder;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -220,10 +220,9 @@ PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
|
|
|
|
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
|
|
|
|
"to create a ManifestNode")
|
|
|
|
"to create a ManifestNode")
|
|
|
|
.def(py::init([](std::string dataset_file, std::string usage, py::handle sampler,
|
|
|
|
.def(py::init([](std::string dataset_file, std::string usage, py::handle sampler,
|
|
|
|
py::dict class_indexing, bool decode, std::shared_ptr<CacheClient> cc) {
|
|
|
|
py::dict class_indexing, bool decode) {
|
|
|
|
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
|
|
|
|
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
|
|
|
|
toStringMap(class_indexing), decode,
|
|
|
|
toStringMap(class_indexing), decode, nullptr);
|
|
|
|
toDatasetCache(std::move(cc)));
|
|
|
|
|
|
|
|
THROW_IF_ERROR(manifest->ValidateParams());
|
|
|
|
THROW_IF_ERROR(manifest->ValidateParams());
|
|
|
|
return manifest;
|
|
|
|
return manifest;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -261,28 +260,25 @@ PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
|
|
|
|
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
|
|
|
|
"to create an MnistNode")
|
|
|
|
"to create an MnistNode")
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler,
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
|
|
|
std::shared_ptr<CacheClient> cc) {
|
|
|
|
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
|
|
|
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler),
|
|
|
|
|
|
|
|
toDatasetCache(std::move(cc)));
|
|
|
|
|
|
|
|
THROW_IF_ERROR(mnist->ValidateParams());
|
|
|
|
THROW_IF_ERROR(mnist->ValidateParams());
|
|
|
|
return mnist;
|
|
|
|
return mnist;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
|
}));
|
|
|
|
}));
|
|
|
|
|
|
|
|
|
|
|
|
PYBIND_REGISTER(
|
|
|
|
PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
|
|
|
|
RandomNode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode",
|
|
|
|
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode")
|
|
|
|
"to create a RandomNode")
|
|
|
|
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list,
|
|
|
|
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list) {
|
|
|
|
std::shared_ptr<CacheClient> cc) {
|
|
|
|
|
|
|
|
auto random_node =
|
|
|
|
auto random_node =
|
|
|
|
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
|
|
|
|
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
|
|
|
|
THROW_IF_ERROR(random_node->ValidateParams());
|
|
|
|
THROW_IF_ERROR(random_node->ValidateParams());
|
|
|
|
return random_node;
|
|
|
|
return random_node;
|
|
|
|
}))
|
|
|
|
}))
|
|
|
|
.def(py::init([](int32_t total_rows, std::string schema, py::list columns_list, std::shared_ptr<CacheClient> cc) {
|
|
|
|
.def(py::init([](int32_t total_rows, std::string schema, py::list columns_list) {
|
|
|
|
auto random_node =
|
|
|
|
auto random_node =
|
|
|
|
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
|
|
|
|
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
|
|
|
|
THROW_IF_ERROR(random_node->ValidateParams());
|
|
|
|
THROW_IF_ERROR(random_node->ValidateParams());
|
|
|
|
return random_node;
|
|
|
|
return random_node;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -292,10 +288,10 @@ PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
|
|
|
|
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
|
|
|
|
"to create a TextFileNode")
|
|
|
|
"to create a TextFileNode")
|
|
|
|
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
|
|
|
|
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
|
|
|
|
int32_t shard_id, std::shared_ptr<CacheClient> cc) {
|
|
|
|
int32_t shard_id) {
|
|
|
|
std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>(
|
|
|
|
std::shared_ptr<TextFileNode> textfile_node =
|
|
|
|
toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id,
|
|
|
|
std::make_shared<TextFileNode>(toStringVector(dataset_files), num_samples,
|
|
|
|
toDatasetCache(std::move(cc)));
|
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
|
|
|
THROW_IF_ERROR(textfile_node->ValidateParams());
|
|
|
|
THROW_IF_ERROR(textfile_node->ValidateParams());
|
|
|
|
return textfile_node;
|
|
|
|
return textfile_node;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -306,19 +302,19 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
|
|
|
|
"to create a TFRecordNode")
|
|
|
|
"to create a TFRecordNode")
|
|
|
|
.def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, py::list columns_list,
|
|
|
|
.def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, py::list columns_list,
|
|
|
|
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
|
|
|
|
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
|
|
|
|
bool shard_equal_rows, std::shared_ptr<CacheClient> cc) {
|
|
|
|
bool shard_equal_rows) {
|
|
|
|
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
|
|
|
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
|
|
|
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
|
|
|
|
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
|
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
|
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
|
|
|
|
THROW_IF_ERROR(tfrecord->ValidateParams());
|
|
|
|
THROW_IF_ERROR(tfrecord->ValidateParams());
|
|
|
|
return tfrecord;
|
|
|
|
return tfrecord;
|
|
|
|
}))
|
|
|
|
}))
|
|
|
|
.def(py::init([](py::list dataset_files, std::string schema, py::list columns_list,
|
|
|
|
.def(py::init([](py::list dataset_files, std::string schema, py::list columns_list,
|
|
|
|
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
|
|
|
|
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
|
|
|
|
bool shard_equal_rows, std::shared_ptr<CacheClient> cc) {
|
|
|
|
bool shard_equal_rows) {
|
|
|
|
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
|
|
|
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
|
|
|
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
|
|
|
|
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
|
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
|
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
|
|
|
|
THROW_IF_ERROR(tfrecord->ValidateParams());
|
|
|
|
THROW_IF_ERROR(tfrecord->ValidateParams());
|
|
|
|
return tfrecord;
|
|
|
|
return tfrecord;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -326,12 +322,10 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
|
|
|
|
|
|
|
|
|
|
|
|
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
|
|
|
|
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
|
|
|
|
.def(
|
|
|
|
.def(py::init([](std::string dataset_dir, std::string task, std::string usage,
|
|
|
|
py::init([](std::string dataset_dir, std::string task, std::string usage, py::dict class_indexing,
|
|
|
|
py::dict class_indexing, bool decode, py::handle sampler) {
|
|
|
|
bool decode, py::handle sampler, std::shared_ptr<CacheClient> cc) {
|
|
|
|
std::shared_ptr<VOCNode> voc = std::make_shared<VOCNode>(
|
|
|
|
std::shared_ptr<VOCNode> voc =
|
|
|
|
dataset_dir, task, usage, toStringMap(class_indexing), decode, toSamplerObj(sampler), nullptr);
|
|
|
|
std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
|
|
|
|
|
|
|
|
toSamplerObj(sampler), toDatasetCache(std::move(cc)));
|
|
|
|
|
|
|
|
THROW_IF_ERROR(voc->ValidateParams());
|
|
|
|
THROW_IF_ERROR(voc->ValidateParams());
|
|
|
|
return voc;
|
|
|
|
return voc;
|
|
|
|
}));
|
|
|
|
}));
|
|
|
@ -439,11 +433,11 @@ PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
|
|
|
|
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
|
|
|
|
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
|
|
|
|
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
|
|
|
|
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns,
|
|
|
|
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns,
|
|
|
|
py::list output_columns, py::list project_columns, std::shared_ptr<CacheClient> cc,
|
|
|
|
py::list output_columns, py::list project_columns,
|
|
|
|
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) {
|
|
|
|
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) {
|
|
|
|
auto map = std::make_shared<MapNode>(
|
|
|
|
auto map = std::make_shared<MapNode>(
|
|
|
|
self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
|
|
|
|
self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
|
|
|
|
toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)),
|
|
|
|
toStringVector(output_columns), toStringVector(project_columns), nullptr,
|
|
|
|
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()));
|
|
|
|
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()));
|
|
|
|
THROW_IF_ERROR(map->ValidateParams());
|
|
|
|
THROW_IF_ERROR(map->ValidateParams());
|
|
|
|
return map;
|
|
|
|
return map;
|
|
|
|