|
|
|
@ -16,10 +16,9 @@
|
|
|
|
|
#include "minddata/dataset/api/python/de_pipeline.h"
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <set>
|
|
|
|
|
|
|
|
|
|
#include "utils/ms_utils.h"
|
|
|
|
|
#include "minddata/dataset/callback/py_ds_callback.h"
|
|
|
|
|
#include "minddata/dataset/core/tensor.h"
|
|
|
|
|
#include "minddata/dataset/engine/cache/cache_client.h"
|
|
|
|
@ -32,15 +31,15 @@
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
|
|
|
|
#include "minddata/dataset/kernels/py_func_op.h"
|
|
|
|
|
#include "minddata/dataset/util/random.h"
|
|
|
|
|
#include "minddata/dataset/util/status.h"
|
|
|
|
@ -53,6 +52,7 @@
|
|
|
|
|
#include "minddata/mindrecord/include/shard_writer.h"
|
|
|
|
|
#include "pybind11/stl.h"
|
|
|
|
|
#include "utils/log_adapter.h"
|
|
|
|
|
#include "utils/ms_utils.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
@ -211,16 +211,17 @@ Status DEPipeline::GetColumnNames(py::list *output) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Fetches the next row from the pipeline iterator and stores it in a Python dict.
//
// @param output Python dict that receives one entry per (name, tensor) pair of the row;
//               it is dereferenced without a null check, so it must point to a valid dict.
// @return The non-OK Status from the iterator if the fetch fails, otherwise Status::OK().
//
// The row is fetched exactly once, as an ordered vector of pairs, and the entries are
// inserted in that order so the resulting dict keeps the order the iterator produced.
// NOTE(review): presumably the pair order is the dataset's column order - confirm against
// the iterator's GetNextAsOrderedPair implementation.
Status DEPipeline::GetNextAsMap(py::dict *output) {
  std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> vec;
  Status s;
  {
    // The GIL is released only for the duration of the fetch; it is held again before
    // any Python object is touched below.
    py::gil_scoped_release gil_release;
    s = iterator_->GetNextAsOrderedPair(&vec);
  }
  RETURN_IF_NOT_OK(s);

  // Generate Python dict, python dict maintains its insertion order
  for (const auto &pair : vec) {
    (*output)[common::SafeCStr(pair.first)] = pair.second;
  }
  return Status::OK();
}
|
|
|
|
@ -614,7 +615,7 @@ Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map<std::string,
|
|
|
|
|
}
|
|
|
|
|
if (mr_shape.empty()) {
|
|
|
|
|
if (mr_type == "bytes") { // map to int32 when bytes without shape.
|
|
|
|
|
mr_type == "int32";
|
|
|
|
|
mr_type = "int32";
|
|
|
|
|
}
|
|
|
|
|
(*schema)[column_name] = {{"type", mr_type}};
|
|
|
|
|
} else {
|
|
|
|
@ -905,7 +906,7 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|
|
|
|
if (py::isinstance<py::int_>(args["batch_size"])) {
|
|
|
|
|
batch_size_ = ToInt(args["batch_size"]);
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid.");
|
|
|
|
|
builder = std::make_shared<BatchOp::Builder>(ToInt(args["batch_size"]));
|
|
|
|
|
builder = std::make_shared<BatchOp::Builder>(batch_size_);
|
|
|
|
|
} else if (py::isinstance<py::function>(args["batch_size"])) {
|
|
|
|
|
builder = std::make_shared<BatchOp::Builder>(1);
|
|
|
|
|
(void)builder->SetBatchSizeFunc(args["batch_size"].cast<py::function>());
|
|
|
|
@ -920,17 +921,13 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|
|
|
|
if (!value.is_none()) {
|
|
|
|
|
if (key == "drop_remainder") {
|
|
|
|
|
(void)builder->SetDrop(ToBool(value));
|
|
|
|
|
}
|
|
|
|
|
if (key == "num_parallel_workers") {
|
|
|
|
|
} else if (key == "num_parallel_workers") {
|
|
|
|
|
(void)builder->SetNumWorkers(ToInt(value));
|
|
|
|
|
}
|
|
|
|
|
if (key == "per_batch_map") {
|
|
|
|
|
} else if (key == "per_batch_map") {
|
|
|
|
|
(void)builder->SetBatchMapFunc(value.cast<py::function>());
|
|
|
|
|
}
|
|
|
|
|
if (key == "input_columns") {
|
|
|
|
|
} else if (key == "input_columns") {
|
|
|
|
|
(void)builder->SetColumnsToMap(ToStringVector(value));
|
|
|
|
|
}
|
|
|
|
|
if (key == "pad_info") {
|
|
|
|
|
} else if (key == "pad_info") {
|
|
|
|
|
PadInfo pad_info;
|
|
|
|
|
RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info));
|
|
|
|
|
(void)builder->SetPaddingMap(pad_info, true);
|
|
|
|
|