Random Data Op

pull/869/head
Jesse Lee 5 years ago
parent 05676676e9
commit 270bf831a9

@@ -28,6 +28,7 @@
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "mindrecord/include/shard_category.h"
@@ -65,6 +66,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
@@ -972,6 +974,45 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
return Status::OK();
}
Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
// Required arguments
RandomDataOp::Builder builder;
if (args["num_samples"].is_none()) {
std::string err_msg = "Error: num_samples is a required argument";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::vector<std::string> columns_to_load;
bool schema_exists = false;
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (key == "num_parallel_workers") {
(void)builder.SetNumWorkers(ToInt(value));
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "num_samples") {
(void)builder.SetTotalRows(ToInt(value));
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
}
}
if (schema_exists) {
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
if (args.contains("schema_file_path")) {
RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load));
} else {
RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load));
}
(void)builder.SetDataSchema(std::move(schema));
}
std::shared_ptr<RandomDataOp> op;
RETURN_IF_NOT_OK(builder.Build(&op));
*ptr = op;
return Status::OK();
}
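For context, the loop above consumes the keyword-argument dict that the Python layer forwards to the parser. A rough sketch of that dict is shown below; only the key names come from the branches above, the values are made up for illustration.
# Illustrative sketch of the args dict that ParseRandomDataOp receives from the
# Python layer (values are placeholders; only these keys are handled above).
args = {
    "num_samples": 50,                    # required; mapped to builder.SetTotalRows
    "num_parallel_workers": 4,            # optional; mapped to builder.SetNumWorkers
    "schema_file_path": "./schema.json",  # or "schema_json_string"; loaded into a DataSchema
    "columns_list": ["image", "label"],   # optional; limits which schema columns are loaded
}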
int32_t DEPipeline::GetNumClasses() const { return num_classes_; }
Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {

@@ -60,6 +60,7 @@ enum OpName {
kCifar10,
kCifar100,
kCelebA,
kRandomData,
kTextFile
};
@@ -142,6 +143,8 @@ class DEPipeline {
Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
void PrintTree();
int32_t GetNumClasses() const;

@@ -47,6 +47,7 @@
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
@@ -489,6 +490,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("VOC", OpName::kVoc)
.value("CIFAR10", OpName::kCifar10)
.value("CIFAR100", OpName::kCifar100)
.value("RANDOMDATA", OpName::kRandomData)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile);

@@ -466,5 +466,24 @@ Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) {
"\"columns\" node is required in the schema json file.");
return Status::OK();
}
// Loops through all columns in the schema and returns a map from column
// name to column index.
Status DataSchema::GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map) {
if (out_column_name_map == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"unexpected null output column name map.");
}
for (int32_t i = 0; i < col_descs_.size(); ++i) {
if (col_descs_[i].name().empty()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Constructing column name map from schema, but found empty column name.");
}
(*out_column_name_map)[col_descs_[i].name()] = i;
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@@ -20,6 +20,7 @@
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <nlohmann/json.hpp>
#include "dataset/core/constants.h"
@@ -180,6 +181,12 @@ class DataSchema {
static const char DEFAULT_DATA_SCHEMA_FILENAME[];
// Loops through all columns in the schema and returns a map from column
// name to column index.
// @param out_column_name_map - The output map of column names to column index
// @return Status - The error code return
Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map);
private:
// Internal helper function. Parses the json schema file in any order and produces a schema that
// does not follow any particular order (json standard does not enforce any ordering protocol).

@@ -17,6 +17,7 @@ add_library(engine-datasetops-source OBJECT
${FEATURE_SRCS}
manifest_op.cc
cifar_op.cc
random_data_op.cc
celeba_op.cc
text_file_op.cc
)

@@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
from .core.configuration import config
from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
Schema, Shuffle, zip
Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show

@@ -3146,6 +3146,57 @@ class Cifar100Dataset(SourceDataset):
return get_num_rows(num_rows, self.num_shards)
class RandomDataset(SourceDataset):
"""
A source dataset that generates random data.
Args:
num_samples (int): Number of samples (rows) to generate.
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
If the schema is not provided, a random schema (random column names, types and shapes) is generated.
columns_list (list[str], optional): List of columns to be included (default=None, include all columns).
num_parallel_workers (int, optional): Number of workers to generate the data
(default=None, number set in the config).
"""
def __init__(self, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None):
super().__init__(num_parallel_workers)
schema_obj = None
if (schema is not None) and (not isinstance(schema, Schema)):
schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
self.schema = schema
self.columns_list = columns_list
self.num_samples = num_samples
if schema_obj is not None and num_samples is None:
self.num_samples = schema_obj.num_rows
def get_args(self):
args = super().get_args()
if self.schema is not None:
if isinstance(self.schema, Schema):
self.schema.datasetType = 'Random'
if self.num_samples is not None:
self.schema.num_rows = self.num_samples
args["schema_json_string"] = self.schema.to_json()
else:
args["schema_file_path"] = self.schema
args["schema"] = self.schema
if self.columns_list is not None:
args["columns_list"] = self.columns_list
if self.num_samples is not None:
args["num_samples"] = self.num_samples
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
return self.num_samples
class Schema:
"""
Class to represent a schema of dataset.
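A minimal usage sketch for the new RandomDataset class, following the same pattern as the tests added at the end of this change (the in-memory Schema below is assumed; only RandomDataset itself is new API):

import mindspore.common.dtype as mstype
import mindspore.dataset as ds

# Build a tiny schema and generate a handful of random rows from it (sketch only).
schema = ds.Schema()
schema.add_column('label', de_type=mstype.uint8, shape=[1])
data = ds.RandomDataset(schema=schema, num_samples=8, num_parallel_workers=2)
for row in data.create_dict_iterator():
    print(row["label"])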

@@ -192,6 +192,8 @@ class Iterator:
op_type = OpName.CIFAR100
elif isinstance(dataset, de.CelebADataset):
op_type = OpName.CELEBA
elif isinstance(dataset, de.RandomDataset):
op_type = OpName.RANDOMDATA
elif isinstance(dataset, de.TextFileDataset):
op_type = OpName.TEXTFILE
else:

File diff suppressed because it is too large

@@ -0,0 +1,14 @@
{
"columns": {
"image": {
"type": "uint8",
"rank": 3,
"shape": [1920,1080,3]
},
"label": {
"type": "int32",
"rank": 1,
"shape": [1]
}
}
}
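The schema above describes a 1920x1080x3 uint8 image column and an int32 label. Below is a sketch of driving RandomDataset from a schema file like this one; the file name is a placeholder for wherever the json lives in the test data tree:

import mindspore.dataset as ds

# "datasetSchema.json" stands in for the path of the schema file shown above.
data = ds.RandomDataset(schema="datasetSchema.json",
                        columns_list=["image", "label"],
                        num_samples=4)
for row in data.create_dict_iterator():
    print(row["image"].shape, row["label"])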

@@ -0,0 +1,14 @@
{
"columns": {
"image": {
"type": "uint8",
"rank": 2,
"shape": [28,28]
},
"label": {
"type": "uint8",
"rank": 1,
"shape": [1]
}
}
}

@@ -0,0 +1,70 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from pathlib import Path
# just a basic test with parallel random data op
def test_randomdataset_basic1():
print("Test randomdataset basic")
schema = ds.Schema()
schema.add_column('image', de_type=mstype.uint8, shape=[2])
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# apply dataset operations
ds1 = ds.RandomDataset(schema=schema, num_samples=50, num_parallel_workers=4)
ds1 = ds1.repeat(4)
num_iter = 0
for data in ds1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
print("{} image: {}".format(num_iter, data["image"]))
print("{} label: {}".format(num_iter, data["label"]))
num_iter += 1
print("Number of data in ds1: ", num_iter)
assert(num_iter == 200)
# Another simple test
def test_randomdataset_basic2():
print("Test randomdataset basic 2")
schema = ds.Schema()
schema.add_column('image', de_type=mstype.uint8, shape=[640,480,3]) # 921600 bytes (a bit less than 1 MB per image)
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# Make up about 10 samples
ds1 = ds.RandomDataset(schema=schema, num_samples=10, num_parallel_workers=1)
# repeat the 10 generated rows 4 times (40 rows, roughly 9 MB of image data per epoch)
ds1 = ds1.repeat(4)
num_iter = 0
for data in ds1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
#print(data["image"])
print("printing the label: {}".format(data["label"]))
num_iter += 1
print("Number of data in ds1: ", num_iter)
assert(num_iter == 40)
if __name__ == '__main__':
test_randomdataset_basic1()
test_randomdataset_basic2()
print('test_randomdataset_basic Ended.\n')