!3049 [dataset] add save operator in dataset

Merge pull request !3049 from liyong126/dataset_save_op
pull/3049/MERGE
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit 526770e09e

@@ -42,11 +42,17 @@
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_category.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_header.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "minddata/mindrecord/include/shard_writer.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
using json = nlohmann::json;
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *, std::shared_ptr<DatasetOp> *);
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
@@ -355,6 +361,226 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
return Status::OK();
}
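// Drive the pipeline's iterator and persist every fetched TensorRow into mindrecord files:
// build the shard header from the first row, stream the raw (json) and blob parts of each
// row through the ShardWriter, then generate the index database.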
Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type) {
Status s;
auto mr_header = std::make_shared<mindrecord::ShardHeader>();
auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
std::vector<std::string> blob_fields;
uint64_t mr_schema_id = 0;
if (mindrecord::SUCCESS != mindrecord::ShardWriter::initialize(&mr_writer, file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter.");
}
TensorRow row;
std::unordered_map<std::string, int32_t> column_name_id_map =
iterator_->GetColumnNameMap(); // map of column name, id
bool first_loop = true; // build schema in first loop
do {
json row_raw_data;
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
{
py::gil_scoped_release gil_release;
s = iterator_->FetchNextTensorRow(&row);
}
RETURN_IF_NOT_OK(s);
if (row.empty()) break;
if (first_loop) {
json mr_json;
std::vector<std::string> index_fields;
s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
RETURN_IF_NOT_OK(s);
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
mr_writer->SetShardHeader(mr_header);
first_loop = false;
}
// construct data
if (!row.empty()) { // write data
s = FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data);
RETURN_IF_NOT_OK(s);
std::shared_ptr<std::vector<uint8_t>> output_bin_data;
mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data);
std::map<std::uint64_t, std::vector<json>> raw_data;
raw_data.insert(std::pair<uint64_t, std::vector<json>>(mr_schema_id, std::vector<json>{row_raw_data}));
std::vector<std::vector<uint8_t>> bin_data;
if (nullptr != output_bin_data) {
bin_data.emplace_back(*output_bin_data);
}
mr_writer->WriteRawData(raw_data, bin_data);
}
} while (!row.empty());
mr_writer->Commit();
mindrecord::ShardIndexGenerator::finalize(file_names);
return Status::OK();
}
Status DEPipeline::FetchDataFromTensorRow(const TensorRow &row,
const std::unordered_map<std::string, int32_t> &column_name_id_map,
json *row_raw_data,
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
if (row_raw_data == nullptr) {
RETURN_STATUS_UNEXPECTED("error: row raw data is NULL.");
}
if (row_bin_data == nullptr) {
RETURN_STATUS_UNEXPECTED("error: row bin data is NULL.");
}
if (column_name_id_map.empty()) {
RETURN_STATUS_UNEXPECTED("Error: column not found");
}
Status s;
for (auto &col : column_name_id_map) {
auto idx = col.second;
auto column_name = col.first;
auto &tensor = row[idx];
auto column_type = tensor->type();
std::unique_ptr<std::vector<uint8_t>> data_ptr;
if (column_type == DataType::DE_INT8) {
std::unique_ptr<int32_t> data;
std::unique_ptr<int8_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT16) {
std::unique_ptr<int32_t> data;
std::unique_ptr<int16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT16) {
std::unique_ptr<int32_t> data;
std::unique_ptr<uint16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT8) {
std::unique_ptr<uint8_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT32) {
std::unique_ptr<int32_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT32) {
std::unique_ptr<int64_t> data;
std::unique_ptr<uint32_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT64) {
std::unique_ptr<int64_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT32) {
std::unique_ptr<float> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT64) {
std::unique_ptr<double> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_STRING) {
auto buffer = tensor->GetStringsBuffer();
std::string ss(reinterpret_cast<const char *>(buffer)); // assume scalar string tensor
(*row_raw_data)[column_name] = std::move(ss);
continue;
} else {
RETURN_STATUS_UNEXPECTED("Got unexpected type when casting data.");
}
RETURN_IF_NOT_OK(s);
if (data_ptr != nullptr) {
(*row_bin_data)[column_name] = std::move(data_ptr);
}
}
return Status::OK();
}
template <typename T, typename S>
Status DEPipeline::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert) {
if (nullptr == src) {
RETURN_STATUS_UNEXPECTED("Error: buffer of Tensor is NULL.");
}
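// Copy the tensor buffer into a plain byte vector. When need_convert is set, every element
// of the source type S is widened to the destination type T (e.g. int8 -> int32) so the data
// matches the type recorded in the mindrecord schema. For a scalar tensor (empty shape) the
// first element is also returned through *data so the caller can place it into the raw json.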
*data_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(T));
if (need_convert) {
auto tmp_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(S));
std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin());
auto s_ptr = reinterpret_cast<S *>(&(*(tmp_ptr->begin())));
auto el = std::make_unique<T>();
for (uint32_t i = 0; i < num_of_elements; ++i) {
*el = *(s_ptr + i);
auto t_ptr = reinterpret_cast<uint8_t *>(el.get());
for (uint32_t j = 0; j < sizeof(T); ++j) {
*((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j);
}
}
} else {
std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin());
}
if (shape.empty()) {
*data = std::make_unique<T>();
auto t_ptr = reinterpret_cast<uint8_t *>((*data).get());
for (uint32_t i = 0; i < sizeof(T); ++i) {
*(t_ptr + i) = *((*data_ptr)->begin() + i);
}
}
return Status::OK();
}
Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
const TensorRow &row, json *schema, std::vector<std::string> *index_fields) {
if (schema == nullptr) {
RETURN_STATUS_UNEXPECTED("error: schema is NULL.");
}
if (index_fields == nullptr) {
RETURN_STATUS_UNEXPECTED("error: index fields is NULL.");
}
if (column_name_id_map.empty()) {
RETURN_STATUS_UNEXPECTED("Error: column not found.");
}
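// Derive a mindrecord schema entry for every column via kTypesMap; scalar columns that are
// not stored as bytes are also collected as candidate index fields.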
for (auto &col : column_name_id_map) {
auto idx = col.second;
auto column_name = col.first;
auto &tensor = row[idx];
auto column_type = tensor->type();
auto column_shape = tensor->shape();
std::string mr_type;
auto shapes = column_shape.AsVector();
std::vector<int> mr_shape(shapes.begin(), shapes.end());
std::string el = column_type.ToString();
if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
std::string err_msg("Error: can not support data type: " + el);
RETURN_STATUS_UNEXPECTED(err_msg);
} else {
mr_type = mindrecord::kTypesMap.at(el);
}
if (mr_shape.empty()) {
if (mr_type == "bytes") { // map to int32 when bytes without shape.
mr_type == "int32";
}
(*schema)[column_name] = {{"type", mr_type}};
} else {
if (mr_type == "string") { // mindrecord can not support string with shape.
std::string err_msg("Error: mindrecord can not support multi-dimensional string tensor.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (mr_type == "bytes") { // ignore shape of bytes in minrecord
(*schema)[column_name] = {{"type", mr_type}};
} else {
(*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}};
}
}
if (mr_type == "bytes" || !mr_shape.empty()) continue;
index_fields->emplace_back(column_name); // candidate of index fields
}
return Status::OK();
}
Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded) {

@@ -17,6 +17,7 @@
#define DATASET_API_DE_PIPELINE_H_
#include <iostream>
#include <map>
#include <memory>
#include <stack>
#include <string>
@@ -33,6 +34,7 @@
namespace py = pybind11;
namespace mindspore {
namespace dataset {
using json = nlohmann::json;
using DsOpPtr = std::shared_ptr<DatasetOp>;
class CacheClient;
@@ -100,6 +102,8 @@ class DEPipeline {
Status GetOutputTypes(py::list *output);
Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);
int GetDatasetSize() const;
int GetBatchSize() const;
@@ -110,6 +114,18 @@
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
template <typename T, typename S>
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert = false);
Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
const TensorRow &row, json *schema, std::vector<std::string> *index_fields);
Status FetchDataFromTensorRow(const TensorRow &row,
const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data,
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);
Status BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded);

@@ -184,7 +184,11 @@ void bindDEPipeline(py::module *m) {
.def("GetDatasetSize", &DEPipeline::GetDatasetSize)
.def("GetBatchSize", &DEPipeline::GetBatchSize)
.def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount);
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true;
});
}
void bindDatasetOps(py::module *m) {
(void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")

@@ -312,6 +312,11 @@ class Tensor {
// @return const unsigned char*
const unsigned char *GetBuffer() const;
// Skip the offsets and return the start of the buffer where the real strings are stored. The caller needs to
// check that the tensor's type is string, otherwise an undefined address would be returned.
// @return address of the first string of the tensor.
uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }
// Getter of the type
// @return
DataType type() const { return type_; }
@@ -643,11 +648,6 @@ class Tensor {
// @return length of the string
Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const;
// Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the
// tensor's type is a string, otherwise undefined address would be returned.
// @return address of the first string of the tensor.
uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }
// all access to shape_ should be via shape
TensorShape shape_;
// data type of tensor

@@ -215,7 +215,7 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\n Dataset file : ";
out << "\nDataset file : ";
for (auto &file : dataset_file_) {
out << file << " ";
}

@@ -137,6 +137,10 @@ const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "
// number field list
const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};
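// map from dataset (numpy) type names to the type mindrecord stores them as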
const std::unordered_map<std::string, std::string> kTypesMap = {
{"bool", "int32"}, {"int8", "int32"}, {"uint8", "bytes"}, {"int16", "int32"},
{"uint16", "int32"}, {"int32", "int32"}, {"uint32", "int64"}, {"int64", "int64"},
{"float16", "float32"}, {"float32", "float32"}, {"float64", "float64"}, {"string", "string"}};
/// \brief split a string using a character
/// \param[in] field target string
/// \param[in] separator a character for splitting

@@ -124,6 +124,10 @@ class ShardHeader {
MSRStatus FileToPages(const std::string dump_file_name);
static MSRStatus initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id);
private:
MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);

@@ -57,6 +57,8 @@ class ShardIndexGenerator {
/// \brief create databases for indexes
MSRStatus WriteToDatabase();
static MSRStatus finalize(const std::vector<std::string> file_names);
private:
static int Callback(void *not_used, int argc, char **argv, char **az_col_name);

@@ -108,6 +108,13 @@ class ShardWriter {
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
bool parallel_writer = false);
MSRStatus MergeBlobData(const std::vector<string> &blob_fields,
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
std::shared_ptr<std::vector<uint8_t>> *output);
static MSRStatus initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
const std::vector<std::string> &file_names);
private:
/// \brief write shard header data to disk
MSRStatus WriteShardHeader();

@@ -622,5 +622,21 @@ void ShardIndexGenerator::DatabaseWriter() {
shard_no = task_++;
}
}
MSRStatus ShardIndexGenerator::finalize(const std::vector<std::string> file_names) {
if (file_names.empty()) {
MS_LOG(ERROR) << "Mindrecord files is empty.";
return FAILED;
}
ShardIndexGenerator sg{file_names[0]};
if (SUCCESS != sg.Build()) {
MS_LOG(ERROR) << "Failed to build index generator.";
return FAILED;
}
if (SUCCESS != sg.WriteToDatabase()) {
MS_LOG(ERROR) << "Failed to write to database.";
return FAILED;
}
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore

@@ -637,6 +637,42 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
*row_count = std::get<2>(v);
return SUCCESS;
}
MSRStatus ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
std::shared_ptr<std::vector<uint8_t>> *output) {
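// Merge the per-column blob buffers of one row into the single blob stream kept by mindrecord:
// a single blob field is copied as-is, while several fields are concatenated back-to-back,
// each preceded by its length encoded as an 8-byte big-endian integer.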
if (blob_fields.empty()) {
return SUCCESS;
}
if (blob_fields.size() == 1) {
auto &blob = row_bin_data.at(blob_fields[0]);
auto blob_size = blob->size();
*output = std::make_shared<std::vector<uint8_t>>(blob_size);
std::copy(blob->begin(), blob->end(), (*output)->begin());
} else {
size_t output_size = 0;
for (auto &field : blob_fields) {
output_size += row_bin_data.at(field)->size();
}
output_size += blob_fields.size() * sizeof(uint64_t);
*output = std::make_shared<std::vector<uint8_t>>(output_size);
std::vector<uint8_t> buf(sizeof(uint64_t), 0);
size_t idx = 0;
for (auto &field : blob_fields) {
auto &blob = row_bin_data.at(field);
uint64_t blob_size = blob->size();
// big endian
for (size_t i = 0; i < buf.size(); ++i) {
buf[buf.size() - 1 - i] = std::numeric_limits<uint8_t>::max() & blob_size;
blob_size >>= 8u;
}
std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
idx += buf.size();
std::copy(blob->begin(), blob->end(), (*output)->begin() + idx);
idx += blob->size();
}
}
return SUCCESS;
}
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
@@ -1250,5 +1286,21 @@ void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &la
last_blob_page = page.first;
}
}
MSRStatus ShardWriter::initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
const std::vector<std::string> &file_names) {
if (nullptr == writer_ptr) {
MS_LOG(ERROR) << "ShardWriter pointer is NULL.";
return FAILED;
}
auto res = (*writer_ptr)->Open(file_names, false);
if (SUCCESS != res) {
MS_LOG(ERROR) << "Failed to open mindrecord files to writer.";
return FAILED;
}
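// default sizes for the generated files: 16 MB shard header, 32 MB pages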
(*writer_ptr)->SetHeaderSize(1 << 24);
(*writer_ptr)->SetPageSize(1 << 25);
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore

@@ -721,5 +721,35 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
page_in_handle.close();
return SUCCESS;
}
MSRStatus ShardHeader::initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id) {
if (nullptr == header_ptr) {
MS_LOG(ERROR) << "ShardHeader pointer is NULL.";
return FAILED;
}
auto schema_ptr = Schema::Build("mindrecord", schema);
if (nullptr == schema_ptr) {
MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema.";
return FAILED;
}
schema_id = (*header_ptr)->AddSchema(schema_ptr);
// create index
std::vector<std::pair<uint64_t, std::string>> id_index_fields;
if (!index_fields.empty()) {
for (auto &el : index_fields) {
id_index_fields.emplace_back(schema_id, el);
}
if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) {
MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index.";
return FAILED;
}
}
auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
blob_fields = build_schema_ptr->GetBlobFields();
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore

@@ -38,13 +38,13 @@ from mindspore._c_expression import typing
from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
@@ -1044,6 +1044,34 @@ class Dataset:
return TransferDataset(self, queue_name, device_id, device_type, num_batch)
@check_save
def save(self, file_name, num_files=1, file_type='mindrecord'):
"""
Save the dynamic data processed by the dataset pipeline in a common dataset format. Only 'mindrecord' is supported.
Note:
1. To save samples in order, set the dataset's shuffle to False and num_files to 1.
2. Before calling the function, do not use the batch or repeat operator, or data augmentation
operators with a random attribute in the map operator.
3. Mindrecord does not support np.uint64, multi-dimensional np.uint8 (the dimension is dropped)
or multi-dimensional strings.
Args:
file_name (str): Path to dataset file.
num_files (int, optional): Number of dataset files (default=1).
file_type (str, optional): Dataset format (default='mindrecord').
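Examples:
>>> # a minimal usage sketch: the path and the NumpySlicesDataset source are
>>> # illustrative; shuffle=False keeps the saved samples in order
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> data = ds.NumpySlicesDataset({"label": np.arange(10, dtype=np.int32)}, shuffle=False)
>>> data.save("/path/to/saved.mindrecord")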
"""
if num_files == 1:
file_names = [file_name]
else:
suffix = len(str(num_files - 1))
file_names = ["{}{}".format(file_name, str(x).rjust(suffix, '0'))
for x in range(num_files)]
return SaveOp(self).save(file_names, file_type)
def create_tuple_iterator(self, columns=None):
"""
Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data.

@@ -173,6 +173,7 @@ class Iterator:
# Convert python node into C node and add to C layer execution tree in postorder traversal.
def __convert_node_postorder(self, node):
self.check_node_type(node)
op_type = self.__get_dataset_type(node)
c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
@@ -224,6 +225,10 @@ class Iterator:
self._index += 1
return data
@abstractmethod
def check_node_type(self, node):
pass
def get_output_shapes(self):
return [t for t in self.depipeline.GetOutputShapes()]
@@ -245,11 +250,27 @@
def __deepcopy__(self, memo):
return self
class SaveOp(Iterator):
"""
The derived class of Iterator for the dataset save operator.
"""
def get_next(self):
pass
def check_node_type(self, node):
if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
logger.warning("Used shuffle, repeat or batch operator before the save operator.")
def save(self, file_names, file_type):
return self.depipeline.SaveDataset(file_names, file_type)
class DictIterator(Iterator):
"""
The derived class of Iterator with dict type.
"""
def check_node_type(self, node):
pass
def __iter__(self):
return self
@@ -269,6 +290,8 @@ class TupleIterator(Iterator):
"""
The derived class of Iterator with list type.
"""
def check_node_type(self, node):
pass
def __init__(self, dataset, columns=None):
if columns is not None:

@@ -246,7 +246,24 @@ def check_celebadataset(method):
return new_method
def check_save(method):
"""A wrapper that wrap a parameter checker to the save op."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_files']
nreq_param_str = ['file_name', 'file_type']
validate_dataset_param_value(nreq_param_int, param_dict, int)
if param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000:
raise ValueError("num_files should be between {} and {}.".format(1, 1000))
validate_dataset_param_value(nreq_param_str, param_dict, str)
if param_dict.get('file_type') != 'mindrecord':
raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
return method(self, *args, **kwargs)
return new_method
def check_minddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""

File diff suppressed because it is too large.