parent
7ec0b5857a
commit
82103a693d
@ -1,16 +1,29 @@
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
if (ENABLE_PYTHON)
|
||||
add_library(APItoPython OBJECT
|
||||
de_pipeline.cc
|
||||
python_bindings.cc
|
||||
)
|
||||
target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||
endif()
|
||||
add_library(APItoPython OBJECT
|
||||
python/de_pipeline.cc
|
||||
python/pybind_register.cc
|
||||
python/bindings.cc
|
||||
python/bindings/dataset/engine/cache/bindings.cc
|
||||
python/bindings/dataset/core/bindings.cc
|
||||
python/bindings/dataset/kernels/data/bindings.cc
|
||||
python/bindings/dataset/kernels/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/bindings.cc
|
||||
python/bindings/dataset/engine/gnn/bindings.cc
|
||||
python/bindings/dataset/kernels/image/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
|
||||
python/bindings/dataset/text/bindings.cc
|
||||
python/bindings/dataset/text/kernels/bindings.cc
|
||||
python/bindings/mindrecord/include/bindings.cc
|
||||
)
|
||||
target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||
endif ()
|
||||
|
||||
add_library(cpp-API OBJECT
|
||||
datasets.cc
|
||||
iterator.cc
|
||||
transforms.cc
|
||||
samplers.cc
|
||||
)
|
||||
datasets.cc
|
||||
iterator.cc
|
||||
transforms.cc
|
||||
samplers.cc
|
||||
)
|
||||
|
@ -0,0 +1,122 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/api/python/de_pipeline.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(
|
||||
DEPipeline, 0, ([](const py::module *m) {
|
||||
(void)py::class_<DEPipeline>(*m, "DEPipeline")
|
||||
.def(py::init<>())
|
||||
.def(
|
||||
"AddNodeToTree",
|
||||
[](DEPipeline &de, const OpName &op_name, const py::dict &args) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
|
||||
return out;
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def_static("AddChildToParentNode",
|
||||
[](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
|
||||
THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
|
||||
})
|
||||
.def("AssignRootNode",
|
||||
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
|
||||
.def("SetBatchParameters",
|
||||
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
|
||||
.def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
|
||||
.def("GetNextAsMap",
|
||||
[](DEPipeline &de) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(de.GetNextAsMap(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetNextAsList",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetNextAsList(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetOutputShapes",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetOutputShapes(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetOutputTypes",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetOutputTypes(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetDatasetSize", &DEPipeline::GetDatasetSize)
|
||||
.def("GetBatchSize", &DEPipeline::GetBatchSize)
|
||||
.def("GetNumClasses", &DEPipeline::GetNumClasses)
|
||||
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
|
||||
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
|
||||
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
|
||||
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
|
||||
return true;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(OpName, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<OpName>(*m, "OpName", py::arithmetic())
|
||||
.value("SHUFFLE", OpName::kShuffle)
|
||||
.value("BATCH", OpName::kBatch)
|
||||
.value("BUCKETBATCH", OpName::kBucketBatch)
|
||||
.value("BARRIER", OpName::kBarrier)
|
||||
.value("MINDRECORD", OpName::kMindrecord)
|
||||
.value("CACHE", OpName::kCache)
|
||||
.value("REPEAT", OpName::kRepeat)
|
||||
.value("SKIP", OpName::kSkip)
|
||||
.value("TAKE", OpName::kTake)
|
||||
.value("ZIP", OpName::kZip)
|
||||
.value("CONCAT", OpName::kConcat)
|
||||
.value("MAP", OpName::kMap)
|
||||
.value("FILTER", OpName::kFilter)
|
||||
.value("DEVICEQUEUE", OpName::kDeviceQueue)
|
||||
.value("GENERATOR", OpName::kGenerator)
|
||||
.export_values()
|
||||
.value("RENAME", OpName::kRename)
|
||||
.value("TFREADER", OpName::kTfReader)
|
||||
.value("PROJECT", OpName::kProject)
|
||||
.value("IMAGEFOLDER", OpName::kImageFolder)
|
||||
.value("MNIST", OpName::kMnist)
|
||||
.value("MANIFEST", OpName::kManifest)
|
||||
.value("VOC", OpName::kVoc)
|
||||
.value("COCO", OpName::kCoco)
|
||||
.value("CIFAR10", OpName::kCifar10)
|
||||
.value("CIFAR100", OpName::kCifar100)
|
||||
.value("RANDOMDATA", OpName::kRandomData)
|
||||
.value("BUILDVOCAB", OpName::kBuildVocab)
|
||||
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
|
||||
.value("CELEBA", OpName::kCelebA)
|
||||
.value("TEXTFILE", OpName::kTextFile)
|
||||
.value("EPOCHCTRL", OpName::kEpochCtrl)
|
||||
.value("CSV", OpName::kCsv)
|
||||
.value("CLUE", OpName::kClue);
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,114 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/api/python/de_pipeline.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(GlobalContext, 0, ([](const py::module *m) {
|
||||
(void)py::class_<GlobalContext>(*m, "GlobalContext")
|
||||
.def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
|
||||
(void)py::class_<ConfigManager, std::shared_ptr<ConfigManager>>(*m, "ConfigManager")
|
||||
.def("__str__", &ConfigManager::ToString)
|
||||
.def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
|
||||
.def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
|
||||
.def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
|
||||
.def("set_op_connector_size", &ConfigManager::set_op_connector_size)
|
||||
.def("set_seed", &ConfigManager::set_seed)
|
||||
.def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval)
|
||||
.def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
|
||||
.def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
|
||||
.def("get_worker_connector_size", &ConfigManager::worker_connector_size)
|
||||
.def("get_op_connector_size", &ConfigManager::op_connector_size)
|
||||
.def("get_seed", &ConfigManager::seed)
|
||||
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
|
||||
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(Tensor, 0, ([](const py::module *m) {
|
||||
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
|
||||
.def(py::init([](py::array arr) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(Tensor::CreateFromNpArray(arr, &out));
|
||||
return out;
|
||||
}))
|
||||
.def_buffer([](Tensor &tensor) {
|
||||
py::buffer_info info;
|
||||
THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info));
|
||||
return info;
|
||||
})
|
||||
.def("__str__", &Tensor::ToString)
|
||||
.def("shape", &Tensor::shape)
|
||||
.def("type", &Tensor::type)
|
||||
.def("as_array", [](py::object &t) {
|
||||
auto &tensor = py::cast<Tensor &>(t);
|
||||
if (tensor.type() == DataType::DE_STRING) {
|
||||
py::array res;
|
||||
tensor.GetDataAsNumpyStrings(&res);
|
||||
return res;
|
||||
}
|
||||
py::buffer_info info;
|
||||
THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info));
|
||||
return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t);
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TensorShape, 0, ([](const py::module *m) {
|
||||
(void)py::class_<TensorShape>(*m, "TensorShape")
|
||||
.def(py::init<py::list>())
|
||||
.def("__str__", &TensorShape::ToString)
|
||||
.def("as_list", &TensorShape::AsPyList)
|
||||
.def("is_known", &TensorShape::known);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DataType, 0, ([](const py::module *m) {
|
||||
(void)py::class_<DataType>(*m, "DataType")
|
||||
.def(py::init<std::string>())
|
||||
.def(py::self == py::self)
|
||||
.def("__str__", &DataType::ToString)
|
||||
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(BorderType, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<BorderType>(*m, "BorderType", py::arithmetic())
|
||||
.value("DE_BORDER_CONSTANT", BorderType::kConstant)
|
||||
.value("DE_BORDER_EDGE", BorderType::kEdge)
|
||||
.value("DE_BORDER_REFLECT", BorderType::kReflect)
|
||||
.value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<InterpolationMode>(*m, "InterpolationMode", py::arithmetic())
|
||||
.value("DE_INTER_LINEAR", InterpolationMode::kLinear)
|
||||
.value("DE_INTER_CUBIC", InterpolationMode::kCubic)
|
||||
.value("DE_INTER_AREA", InterpolationMode::kArea)
|
||||
.value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
|
||||
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
|
||||
.def(py::init<uint32_t, uint64_t, bool>());
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/engine/datasetops/batch_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(CBatchInfo, 0, ([](const py::module *m) {
|
||||
(void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
|
||||
.def(py::init<int64_t, int64_t, int64_t>())
|
||||
.def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
|
||||
.def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DatasetOp, 0, ([](const py::module *m) {
|
||||
(void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(*m, "DatasetOp");
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,186 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/io_block.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(CifarOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
|
||||
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ClueOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
|
||||
.def_static("get_num_rows", [](const py::list &files) {
|
||||
int64_t count = 0;
|
||||
std::vector<std::string> filenames;
|
||||
for (auto file : files) {
|
||||
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
|
||||
}
|
||||
THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(CsvOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<CsvOp, DatasetOp, std::shared_ptr<CsvOp>>(*m, "CsvOp")
|
||||
.def_static("get_num_rows", [](const py::list &files, bool csv_header) {
|
||||
int64_t count = 0;
|
||||
std::vector<std::string> filenames;
|
||||
for (auto file : files) {
|
||||
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
|
||||
}
|
||||
THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
PYBIND_REGISTER(CocoOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<CocoOp, DatasetOp, std::shared_ptr<CocoOp>>(*m, "CocoOp")
|
||||
.def_static("get_class_indexing",
|
||||
[](const std::string &dir, const std::string &file, const std::string &task) {
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
|
||||
THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
})
|
||||
.def_static("get_num_rows",
|
||||
[](const std::string &dir, const std::string &file, const std::string &task) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ImageFolderOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
|
||||
.def_static("get_num_rows_and_classes", [](const std::string &path) {
|
||||
int64_t count = 0, num_classes = 0;
|
||||
THROW_IF_ERROR(
|
||||
ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
|
||||
return py::make_tuple(count, num_classes);
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ManifestOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
|
||||
.def_static("get_num_rows_and_classes",
|
||||
[](const std::string &file, const py::dict &dict, const std::string &usage) {
|
||||
int64_t count = 0, num_classes = 0;
|
||||
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
|
||||
return py::make_tuple(count, num_classes);
|
||||
})
|
||||
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict,
|
||||
const std::string &usage) {
|
||||
std::map<std::string, int32_t> output_class_indexing;
|
||||
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
});
|
||||
}));
|
||||
PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
|
||||
.def_static("get_num_rows", [](const std::vector<std::string> &paths, bool load_dataset,
|
||||
const py::object &sampler, const int64_t num_padded) {
|
||||
int64_t count = 0;
|
||||
std::shared_ptr<mindrecord::ShardOperator> op;
|
||||
if (py::hasattr(sampler, "create_for_minddataset")) {
|
||||
auto create = sampler.attr("create_for_minddataset");
|
||||
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
}
|
||||
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(MnistOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
|
||||
.def_static("get_num_rows", [](const std::string &dir) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TextFileOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
|
||||
.def_static("get_num_rows", [](const py::list &files) {
|
||||
int64_t count = 0;
|
||||
std::vector<std::string> filenames;
|
||||
for (auto file : files) {
|
||||
!file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back("");
|
||||
}
|
||||
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TFReaderOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
|
||||
.def_static(
|
||||
"get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) {
|
||||
int64_t count = 0;
|
||||
std::vector<std::string> filenames;
|
||||
for (auto l : files) {
|
||||
!l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back("");
|
||||
}
|
||||
THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate));
|
||||
return count;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(VOCOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
|
||||
.def_static("get_num_rows",
|
||||
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
|
||||
return count;
|
||||
})
|
||||
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
|
||||
const std::string &task_mode, const py::dict &dict) {
|
||||
std::map<std::string, int32_t> output_class_indexing;
|
||||
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
});
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,88 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) {
|
||||
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
|
||||
.def("set_num_rows",
|
||||
[](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
|
||||
.def("set_num_samples",
|
||||
[](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
|
||||
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
|
||||
.def("get_indices",
|
||||
[](Sampler &self) {
|
||||
py::array ret;
|
||||
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
|
||||
return ret;
|
||||
})
|
||||
.def("add_child", [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) {
|
||||
THROW_IF_ERROR(self->AddChild(child));
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) {
|
||||
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(
|
||||
*m, "DistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
|
||||
.def(py::init<int64_t, int64_t, bool>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PythonSampler, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
|
||||
.def(py::init<int64_t, py::object>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomSampler, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
|
||||
.def(py::init<int64_t, bool, bool>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SequentialSampler, 1, ([](const py::module *m) {
|
||||
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m,
|
||||
"SequentialSampler")
|
||||
.def(py::init<int64_t, int64_t>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SubsetRandomSampler, 1, ([](const py::module *m) {
|
||||
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(
|
||||
*m, "SubsetRandomSampler")
|
||||
.def(py::init<int64_t, std::vector<int64_t>>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(WeightedRandomSampler, 1, ([](const py::module *m) {
|
||||
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(
|
||||
*m, "WeightedRandomSampler")
|
||||
.def(py::init<int64_t, std::vector<double>, bool>());
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,101 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
|
||||
#include "minddata/dataset/engine/gnn/graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(
|
||||
Graph, 0, ([](const py::module *m) {
|
||||
(void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
|
||||
.def(py::init([](std::string dataset_file, int32_t num_workers) {
|
||||
std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
|
||||
THROW_IF_ERROR(g_out->Init());
|
||||
return g_out;
|
||||
}))
|
||||
.def("get_all_nodes",
|
||||
[](gnn::Graph &g, gnn::NodeType node_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_all_edges",
|
||||
[](gnn::Graph &g, gnn::EdgeType edge_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_nodes_from_edges",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_all_neighbors",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_sampled_neighbors",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
|
||||
std::vector<gnn::NodeType> neighbor_types) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_neg_sampled_neighbors",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
|
||||
gnn::NodeType neg_neighbor_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_node_feature",
|
||||
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
|
||||
TensorRow out;
|
||||
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
|
||||
return out.getRow();
|
||||
})
|
||||
.def("get_edge_feature",
|
||||
[](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
|
||||
TensorRow out;
|
||||
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
|
||||
return out.getRow();
|
||||
})
|
||||
.def("graph_info",
|
||||
[](gnn::Graph &g) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(g.GraphInfo(&out));
|
||||
return out;
|
||||
})
|
||||
.def("random_walk",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
|
||||
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
|
||||
return out;
|
||||
});
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,91 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/api/python/de_pipeline.h"
|
||||
|
||||
#include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h"
|
||||
#include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h"
|
||||
#include "minddata/dataset/kernels/py_func_op.h"
|
||||
#include "mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.h"
|
||||
#include "mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status PyListToTensorOps(const py::list &py_ops, std::vector<std::shared_ptr<TensorOp>> *ops) {
|
||||
RETURN_UNEXPECTED_IF_NULL(ops);
|
||||
for (auto op : py_ops) {
|
||||
if (py::isinstance<TensorOp>(op)) {
|
||||
ops->emplace_back(op.cast<std::shared_ptr<TensorOp>>());
|
||||
} else if (py::isinstance<py::function>(op)) {
|
||||
ops->emplace_back(std::make_shared<PyFuncOp>(op.cast<py::function>()));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("element is neither a TensorOp nor a pyfunc.");
|
||||
}
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops->empty(), "TensorOp list is empty.");
|
||||
for (auto const &op : *ops) {
|
||||
RETURN_UNEXPECTED_IF_NULL(op);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PYBIND_REGISTER(TensorOp, 0, ([](const py::module *m) {
|
||||
(void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
|
||||
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ComposeOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ComposeOp, TensorOp, std::shared_ptr<ComposeOp>>(*m, "ComposeOp")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||
return std::make_shared<ComposeOp>(t_ops);
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(NoOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(
|
||||
*m, "NoOp", "TensorOp that does nothing, for testing purposes only.")
|
||||
.def(py::init<>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomChoiceOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomChoiceOp, TensorOp, std::shared_ptr<RandomChoiceOp>>(*m, "RandomChoiceOp")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||
return std::make_shared<RandomChoiceOp>(t_ops);
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomApplyOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomApplyOp, TensorOp, std::shared_ptr<RandomApplyOp>>(*m, "RandomApplyOp")
|
||||
.def(py::init([](double prob, const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||
if (prob < 0 || prob > 1) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1]."));
|
||||
}
|
||||
return std::make_shared<RandomApplyOp>(prob, t_ops);
|
||||
}));
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,133 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/kernels/data/concatenate_op.h"
|
||||
#include "minddata/dataset/kernels/data/duplicate_op.h"
|
||||
#include "minddata/dataset/kernels/data/fill_op.h"
|
||||
#include "minddata/dataset/kernels/data/mask_op.h"
|
||||
#include "minddata/dataset/kernels/data/one_hot_op.h"
|
||||
#include "minddata/dataset/kernels/data/pad_end_op.h"
|
||||
#include "minddata/dataset/kernels/data/slice_op.h"
|
||||
#include "minddata/dataset/kernels/data/to_float16_op.h"
|
||||
#include "minddata/dataset/kernels/data/type_cast_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(ConcatenateOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ConcatenateOp, TensorOp, std::shared_ptr<ConcatenateOp>>(
|
||||
*m, "ConcatenateOp", "Tensor operation concatenate tensors.")
|
||||
.def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>(), py::arg("axis"),
|
||||
py::arg("prepend").none(true), py::arg("append").none(true));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DuplicateOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp",
|
||||
"Duplicate tensor.")
|
||||
.def(py::init<>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(FillOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
|
||||
*m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
|
||||
.def(py::init<std::shared_ptr<Tensor>>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(MaskOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(
|
||||
*m, "MaskOp", "Tensor mask operation using relational comparator")
|
||||
.def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(OneHotOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(
|
||||
*m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
|
||||
.def(py::init<int32_t>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PadEndOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(
|
||||
*m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.")
|
||||
.def(py::init<TensorShape, std::shared_ptr<Tensor>>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SliceOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp",
|
||||
"Tensor slice operation.")
|
||||
.def(py::init<bool>())
|
||||
.def(py::init([](const py::list &py_list) {
|
||||
std::vector<dsize_t> c_list;
|
||||
for (auto l : py_list) {
|
||||
if (!l.is_none()) {
|
||||
c_list.push_back(py::reinterpret_borrow<py::int_>(l));
|
||||
}
|
||||
}
|
||||
return std::make_shared<SliceOp>(c_list);
|
||||
}))
|
||||
.def(py::init([](const py::tuple &py_slice) {
|
||||
if (py_slice.size() != 3) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
|
||||
}
|
||||
Slice c_slice;
|
||||
if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) {
|
||||
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]),
|
||||
py::reinterpret_borrow<py::int_>(py_slice[1]),
|
||||
py::reinterpret_borrow<py::int_>(py_slice[2]));
|
||||
} else if (py_slice[0].is_none() && py_slice[2].is_none()) {
|
||||
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[1]));
|
||||
} else if (!py_slice[0].is_none() && !py_slice[1].is_none()) {
|
||||
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]),
|
||||
py::reinterpret_borrow<py::int_>(py_slice[1]));
|
||||
}
|
||||
|
||||
if (!c_slice.valid()) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
|
||||
}
|
||||
return std::make_shared<SliceOp>(c_slice);
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ToFloat16Op, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ToFloat16Op, TensorOp, std::shared_ptr<ToFloat16Op>>(
|
||||
*m, "ToFloat16Op", py::dynamic_attr(),
|
||||
"Tensor operator to type cast float32 data to a float16 type.")
|
||||
.def(py::init<>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TypeCastOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(
|
||||
*m, "TypeCastOp", "Tensor operator to type cast data to a specified type.")
|
||||
.def(py::init<DataType>(), py::arg("data_type"))
|
||||
.def(py::init<std::string>(), py::arg("data_type"));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RelationalOp, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
|
||||
.value("EQ", RelationalOp::kEqual)
|
||||
.value("NE", RelationalOp::kNotEqual)
|
||||
.value("LT", RelationalOp::kLess)
|
||||
.value("LE", RelationalOp::kLessEqual)
|
||||
.value("GT", RelationalOp::kGreater)
|
||||
.value("GE", RelationalOp::kGreaterEqual)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,92 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/text/vocab.h"
|
||||
#include "minddata/dataset/text/sentence_piece_vocab.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(Vocab, 0, ([](const py::module *m) {
|
||||
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
|
||||
.def(py::init<>())
|
||||
.def_static("from_list",
|
||||
[](const py::list &words, const py::list &special_tokens, bool special_first) {
|
||||
std::shared_ptr<Vocab> v;
|
||||
THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v));
|
||||
return v;
|
||||
})
|
||||
.def_static(
|
||||
"from_file",
|
||||
[](const std::string &path, const std::string &dlm, int32_t vocab_size,
|
||||
const py::list &special_tokens, bool special_first) {
|
||||
std::shared_ptr<Vocab> v;
|
||||
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v));
|
||||
return v;
|
||||
})
|
||||
.def_static("from_dict", [](const py::dict &words) {
|
||||
std::shared_ptr<Vocab> v;
|
||||
THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v));
|
||||
return v;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SentencePieceVocab, 0, ([](const py::module *m) {
|
||||
(void)py::class_<SentencePieceVocab, std::shared_ptr<SentencePieceVocab>>(*m, "SentencePieceVocab")
|
||||
.def(py::init<>())
|
||||
.def_static("from_file",
|
||||
[](const py::list &paths, const int vocab_size, const float character_coverage,
|
||||
const SentencePieceModel model_type, const py::dict ¶ms) {
|
||||
std::shared_ptr<SentencePieceVocab> v;
|
||||
std::vector<std::string> path_list;
|
||||
for (auto path : paths) {
|
||||
path_list.emplace_back(py::str(path));
|
||||
}
|
||||
std::unordered_map<std::string, std::string> param_map;
|
||||
for (auto param : params) {
|
||||
std::string key = py::reinterpret_borrow<py::str>(param.first);
|
||||
if (key == "input" || key == "vocab_size" || key == "model_prefix" ||
|
||||
key == "character_coverage" || key == "model_type") {
|
||||
continue;
|
||||
}
|
||||
param_map[key] = py::reinterpret_borrow<py::str>(param.second);
|
||||
}
|
||||
THROW_IF_ERROR(SentencePieceVocab::BuildFromFile(
|
||||
path_list, vocab_size, character_coverage, model_type, param_map, &v));
|
||||
return v;
|
||||
})
|
||||
.def_static("save_model", [](const std::shared_ptr<SentencePieceVocab> *vocab, std::string path,
|
||||
std::string filename) {
|
||||
THROW_IF_ERROR(SentencePieceVocab::SaveModel(vocab, path, filename));
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SentencePieceModel, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<SentencePieceModel>(*m, "SentencePieceModel", py::arithmetic())
|
||||
.value("DE_SENTENCE_PIECE_UNIGRAM", SentencePieceModel::kUnigram)
|
||||
.value("DE_SENTENCE_PIECE_BPE", SentencePieceModel::kBpe)
|
||||
.value("DE_SENTENCE_PIECE_CHAR", SentencePieceModel::kChar)
|
||||
.value("DE_SENTENCE_PIECE_WORD", SentencePieceModel::kWord)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,244 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
|
||||
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/lookup_op.h"
|
||||
#include "minddata/dataset/text/kernels/ngram_op.h"
|
||||
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
||||
#include "minddata/dataset/text/kernels/to_number_op.h"
|
||||
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
|
||||
|
||||
#ifdef ENABLE_ICU4C
|
||||
#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/bert_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/case_fold_op.h"
|
||||
#include "minddata/dataset/text/kernels/normalize_utf8_op.h"
|
||||
#include "minddata/dataset/text/kernels/regex_replace_op.h"
|
||||
#include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
#ifdef ENABLE_ICU4C
|
||||
|
||||
PYBIND_REGISTER(BasicTokenizerOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<BasicTokenizerOp, TensorOp, std::shared_ptr<BasicTokenizerOp>>(
|
||||
*m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.")
|
||||
.def(py::init<const bool &, const bool &, const NormalizeForm &, const bool &, const bool &>(),
|
||||
py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
|
||||
py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
|
||||
py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
|
||||
py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken,
|
||||
py::arg("with_offsets") = BasicTokenizerOp::kDefWithOffsets);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(WhitespaceTokenizerOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<WhitespaceTokenizerOp, TensorOp, std::shared_ptr<WhitespaceTokenizerOp>>(
|
||||
*m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.")
|
||||
.def(py::init<const bool &>(), py::arg(" with_offsets ") = WhitespaceTokenizerOp::kDefWithOffsets);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(UnicodeScriptTokenizerOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<UnicodeScriptTokenizerOp, TensorOp, std::shared_ptr<UnicodeScriptTokenizerOp>>(
|
||||
*m, "UnicodeScriptTokenizerOp",
|
||||
"Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const bool &, const bool &>(),
|
||||
py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace,
|
||||
py::arg("with_offsets") = UnicodeScriptTokenizerOp::kDefWithOffsets);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(CaseFoldOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<CaseFoldOp, TensorOp, std::shared_ptr<CaseFoldOp>>(
|
||||
*m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor")
|
||||
.def(py::init<>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(NormalizeUTF8Op, 1, ([](const py::module *m) {
|
||||
(void)py::class_<NormalizeUTF8Op, TensorOp, std::shared_ptr<NormalizeUTF8Op>>(
|
||||
*m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.")
|
||||
.def(py::init<>())
|
||||
.def(py::init<NormalizeForm>(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RegexReplaceOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RegexReplaceOp, TensorOp, std::shared_ptr<RegexReplaceOp>>(
|
||||
*m, "RegexReplaceOp",
|
||||
"Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.")
|
||||
.def(py::init<const std::string &, const std::string &, bool>(), py::arg("pattern"),
|
||||
py::arg("replace"), py::arg("replace_all"));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RegexTokenizerOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RegexTokenizerOp, TensorOp, std::shared_ptr<RegexTokenizerOp>>(
|
||||
*m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.")
|
||||
.def(py::init<const std::string &, const std::string &, const bool &>(), py::arg("delim_pattern"),
|
||||
py::arg("keep_delim_pattern"), py::arg("with_offsets") = RegexTokenizerOp::kDefWithOffsets);
|
||||
}));
|
||||
PYBIND_REGISTER(BertTokenizerOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<BertTokenizerOp, TensorOp, std::shared_ptr<BertTokenizerOp>>(
|
||||
*m, "BertTokenizerOp", "Tokenizer used for Bert text process.")
|
||||
.def(py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &,
|
||||
const bool &, const bool &, const NormalizeForm &, const bool &, const bool &>(),
|
||||
py::arg("vocab"),
|
||||
py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
|
||||
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
|
||||
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
|
||||
py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
|
||||
py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
|
||||
py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
|
||||
py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken,
|
||||
py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(NormalizeForm, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<NormalizeForm>(*m, "NormalizeForm", py::arithmetic())
|
||||
.value("DE_NORMALIZE_NONE", NormalizeForm::kNone)
|
||||
.value("DE_NORMALIZE_NFC", NormalizeForm::kNfc)
|
||||
.value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc)
|
||||
.value("DE_NORMALIZE_NFD", NormalizeForm::kNfd)
|
||||
.value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
#endif
|
||||
|
||||
PYBIND_REGISTER(JiebaTokenizerOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(
|
||||
*m, "JiebaTokenizerOp", "")
|
||||
.def(py::init<const std::string &, const std::string &, const JiebaMode &, const bool &>(),
|
||||
py::arg("hmm_path"), py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix,
|
||||
py::arg("with_offsets") = JiebaTokenizerOp::kDefWithOffsets)
|
||||
.def("add_word", [](JiebaTokenizerOp &self, const std::string word, int freq) {
|
||||
THROW_IF_ERROR(self.AddWord(word, freq));
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(UnicodeCharTokenizerOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
|
||||
*m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
|
||||
.def(py::init<const bool &>(), py::arg("with_offsets") = UnicodeCharTokenizerOp::kDefWithOffsets);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(
|
||||
*m, "LookupOp", "Tensor operation to LookUp each word.")
|
||||
.def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word) {
|
||||
if (vocab == nullptr) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null."));
|
||||
}
|
||||
if (py_word.is_none()) {
|
||||
return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists);
|
||||
}
|
||||
std::string word = py::reinterpret_borrow<py::str>(py_word);
|
||||
WordIdType default_id = vocab->Lookup(word);
|
||||
if (default_id == Vocab::kNoTokenExists) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError,
|
||||
"default unknown token: " + word + " doesn't exist in vocab."));
|
||||
}
|
||||
return std::make_shared<LookupOp>(vocab, default_id);
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(NgramOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<NgramOp, TensorOp, std::shared_ptr<NgramOp>>(*m, "NgramOp",
|
||||
"TensorOp performs ngram mapping.")
|
||||
.def(py::init<const std::vector<int32_t> &, int32_t, int32_t, const std::string &,
|
||||
const std::string &, const std::string &>(),
|
||||
py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"),
|
||||
py::arg("r_pad_token"), py::arg("separator"));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
WordpieceTokenizerOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<WordpieceTokenizerOp, TensorOp, std::shared_ptr<WordpieceTokenizerOp>>(
|
||||
*m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.")
|
||||
.def(
|
||||
py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &, const bool &>(),
|
||||
py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
|
||||
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
|
||||
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
|
||||
py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SlidingWindowOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>(
|
||||
*m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.")
|
||||
.def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis"));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
SentencePieceTokenizerOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<SentencePieceTokenizerOp, TensorOp, std::shared_ptr<SentencePieceTokenizerOp>>(
|
||||
*m, "SentencePieceTokenizerOp", "Tokenize scalar token or 1-D tokens to tokens by sentence piece.")
|
||||
.def(
|
||||
py::init<std::shared_ptr<SentencePieceVocab> &, const SPieceTokenizerLoadType, const SPieceTokenizerOutType>(),
|
||||
py::arg("vocab"), py::arg("load_type") = SPieceTokenizerLoadType::kModel,
|
||||
py::arg("out_type") = SPieceTokenizerOutType::kString)
|
||||
.def(py::init<const std::string &, const std::string &, const SPieceTokenizerLoadType,
|
||||
const SPieceTokenizerOutType>(),
|
||||
py::arg("model_path"), py::arg("model_filename"), py::arg("load_type") = SPieceTokenizerLoadType::kFile,
|
||||
py::arg("out_type") = SPieceTokenizerOutType::kString);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ToNumberOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ToNumberOp, TensorOp, std::shared_ptr<ToNumberOp>>(
|
||||
*m, "ToNumberOp", "TensorOp to convert strings to numbers.")
|
||||
.def(py::init<DataType>(), py::arg("data_type"))
|
||||
.def(py::init<std::string>(), py::arg("data_type"));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TruncateSequencePairOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<TruncateSequencePairOp, TensorOp, std::shared_ptr<TruncateSequencePairOp>>(
|
||||
*m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length")
|
||||
.def(py::init<int64_t>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(JiebaMode, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<JiebaMode>(*m, "JiebaMode", py::arithmetic())
|
||||
.value("DE_JIEBA_MIX", JiebaMode::kMix)
|
||||
.value("DE_JIEBA_MP", JiebaMode::kMp)
|
||||
.value("DE_JIEBA_HMM", JiebaMode::kHmm)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SPieceTokenizerOutType, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<SPieceTokenizerOutType>(*m, "SPieceTokenizerOutType", py::arithmetic())
|
||||
.value("DE_SPIECE_TOKENIZER_OUTTYPE_KString", SPieceTokenizerOutType::kString)
|
||||
.value("DE_SPIECE_TOKENIZER_OUTTYPE_KINT", SPieceTokenizerOutType::kInt)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SPieceTokenizerLoadType, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<SPieceTokenizerLoadType>(*m, "SPieceTokenizerLoadType", py::arithmetic())
|
||||
.value("DE_SPIECE_TOKENIZER_LOAD_KFILE", SPieceTokenizerLoadType::kFile)
|
||||
.value("DE_SPIECE_TOKENIZER_LOAD_KMODEL", SPieceTokenizerLoadType::kModel)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,87 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/mindrecord/include/shard_distributed_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_operator.h"
|
||||
#include "minddata/mindrecord/include/shard_pk_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_sequential_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_shuffle.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(ShardOperator, 0, ([](const py::module *m) {
|
||||
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(
|
||||
*m, "ShardOperator")
|
||||
.def("add_child",
|
||||
[](std::shared_ptr<mindrecord::ShardOperator> self,
|
||||
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) {
|
||||
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
|
||||
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m,
|
||||
"MindrecordDistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, bool, uint32_t, int64_t>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
ShardPkSample, 1, ([](const py::module *m) {
|
||||
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
|
||||
*m, "MindrecordPkSampler")
|
||||
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
|
||||
if (shuffle == true) {
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
|
||||
GetSeed());
|
||||
} else {
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
|
||||
}
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
ShardSample, 0, ([](const py::module *m) {
|
||||
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
|
||||
*m, "MindrecordSubsetRandomSampler")
|
||||
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ShardSequentialSample, 0, ([](const py::module *m) {
|
||||
(void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
|
||||
std::shared_ptr<mindrecord::ShardSequentialSample>>(*m,
|
||||
"MindrecordSequentialSampler")
|
||||
.def(py::init([](int num_samples, int start_index) {
|
||||
return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
ShardShuffle, 1, ([](const py::module *m) {
|
||||
(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
|
||||
*m, "MindrecordRandomSampler")
|
||||
.def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
|
||||
return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
|
||||
}));
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PybindDefinedFunctionRegister &PybindDefinedFunctionRegister::GetSingleton() {
|
||||
static PybindDefinedFunctionRegister instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// This is where we externalize the C logic as python modules
|
||||
PYBIND11_MODULE(_c_dataengine, m) {
|
||||
m.doc() = "pybind11 for _c_dataengine";
|
||||
|
||||
auto all_fns = mindspore::dataset::PybindDefinedFunctionRegister::AllFunctions();
|
||||
|
||||
for (auto &item : all_fns) {
|
||||
for (auto &func : item.second) {
|
||||
func.second(&m);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,81 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef API_PYBIND_API_H_
|
||||
#define API_PYBIND_API_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
|
||||
namespace dataset {
|
||||
// Converts a failed Status into std::runtime_error, which pybind11 translates into a
// Python exception at the binding boundary.  The do { } while (false) wrapper makes
// the macro behave as a single statement (safe inside unbraced if/else).
#define THROW_IF_ERROR(s)                                       \
  do {                                                          \
    Status rc = std::move(s);                                   \
    if (rc.IsError()) throw std::runtime_error(rc.ToString()); \
  } while (false)
|
||||
|
||||
using PybindDefineFunc = std::function<void(py::module *)>;
|
||||
|
||||
class PybindDefinedFunctionRegister {
|
||||
public:
|
||||
static void Register(const std::string &name, const uint8_t &priority, const PybindDefineFunc &fn) {
|
||||
return GetSingleton().RegisterFn(name, priority, fn);
|
||||
}
|
||||
|
||||
PybindDefinedFunctionRegister(const PybindDefinedFunctionRegister &) = delete;
|
||||
|
||||
PybindDefinedFunctionRegister &operator=(const PybindDefinedFunctionRegister &) = delete;
|
||||
|
||||
static std::map<uint8_t, std::map<std::string, PybindDefineFunc>> &AllFunctions() {
|
||||
return GetSingleton().module_fns_;
|
||||
}
|
||||
std::map<uint8_t, std::map<std::string, PybindDefineFunc>> module_fns_;
|
||||
|
||||
protected:
|
||||
PybindDefinedFunctionRegister() = default;
|
||||
|
||||
virtual ~PybindDefinedFunctionRegister() = default;
|
||||
|
||||
static PybindDefinedFunctionRegister &GetSingleton();
|
||||
|
||||
void RegisterFn(const std::string &name, const uint8_t &priority, const PybindDefineFunc &fn) {
|
||||
module_fns_[priority][name] = fn;
|
||||
}
|
||||
};
|
||||
|
||||
// Helper whose constructor registers a pybind definition callback with
// PybindDefinedFunctionRegister.  Intended to be instantiated as a file-scope
// global via the PYBIND_REGISTER macro, so registration runs during static
// initialization of each bindings translation unit.
class PybindDefineRegisterer {
 public:
  // name: registry key; priority: registration group; fn: callback that creates the bindings.
  PybindDefineRegisterer(const std::string &name, const uint8_t &priority, const PybindDefineFunc &fn) {
    PybindDefinedFunctionRegister::Register(name, priority, fn);
  }
  ~PybindDefineRegisterer() = default;
};
|
||||
|
||||
#ifdef ENABLE_PYTHON
// Registers a pybind definition callback under a unique file-scope global.
// `name` must be a valid identifier (it is token-pasted into the global's name),
// `priority` orders registration groups, and `define` is a callable taking
// py::module * that creates the actual bindings.
#define PYBIND_REGISTER(name, priority, define) PybindDefineRegisterer g_pybind_define_f_##name(#name, priority, define)
#else
// Without Python support the macro expands to nothing, so any translation unit
// that uses PYBIND_REGISTER still compiles instead of failing on an undefined macro.
#define PYBIND_REGISTER(name, priority, define)
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // API_PYBIND_API_H_
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue