!503 num_samples and numRows in the schema for TFRecordDataset conflict

Merge pull request !503 from qianlong21st/fix_numRows_num_samples
pull/503/MERGE
mindspore-ci-bot committed 5 years ago via Gitee
commit 6b0ff88b1c

@@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
   std::ifstream in(schemaFile);
   nlohmann::json js;
   in >> js;
-  num_rows = js.value("numRows", 0);
+  if (js.find("numRows") == js.end()) {
+    num_rows = MAX_INTEGER_INT32;
+  } else {
+    num_rows = js.value("numRows", 0);
+  }
   if (num_rows == 0) {
     std::string err_msg =
       "Storage client has not properly done dataset "

@@ -163,6 +163,9 @@ Status TFReaderOp::Init() {
   if (total_rows_ == 0) {
     total_rows_ = data_schema_->num_rows();
   }
+  if (total_rows_ < 0) {
+    RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0");
+  }
   // Build the index with our files such that each file corresponds to a key id.
   RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));
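Per the docstring change further down, total_rows_ is resolved from num_samples first and the schema's numRows second, so the new guard rejects a negative value from either source before the file index is built. A rough sketch of that resolution order (resolve_total_rows is an illustrative helper, not code from this PR):

def resolve_total_rows(num_samples, schema_num_rows):
    # An explicit num_samples takes precedence over the schema's numRows.
    total = num_samples if num_samples else schema_num_rows
    if total < 0:
        raise RuntimeError(
            "The num_sample or numRows for TFRecordDataset should be greater than 0")
    return total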

@@ -1455,7 +1455,7 @@ class StorageDataset(SourceDataset):
     Args:
         dataset_files (list[str]): List of files to be read.
-        schema (str): Path to the json schema file.
+        schema (str): Path to the json schema file. If numRows (parsed from the schema) does not exist, read the full dataset.
         distribution (str, optional): Path of distribution config file (default="").
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
         num_parallel_workers (int, optional): Number of parallel working threads (default=None).
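Concretely, a schema without "numRows" no longer caps StorageDataset; this matches the new test_case_no_rows further down. Usage sketch (the paths are placeholders):

import mindspore.dataset as ds

DATA_DIR = ["/path/to/train-0000-of-0001.data"]   # placeholder path
SCHEMA = "/path/to/datasetNoRowsSchema.json"      # schema with no "numRows" field

data = ds.StorageDataset(DATA_DIR, SCHEMA, columns_list=["image"])
print(data.get_dataset_size())  # all rows in the files; no schema-imposed cap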
@@ -2193,7 +2193,10 @@ class TFRecordDataset(SourceDataset):
         schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
             If the schema is not provided, the meta data from the TFData file is considered the schema.
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
-        num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
+        num_samples (int, optional): number of samples (rows) to read (default=None).
+            If num_samples is None and numRows (parsed from the schema) does not exist, read the full dataset;
+            if num_samples is None and numRows (parsed from the schema) is greater than 0, read numRows rows;
+            if both num_samples and numRows (parsed from the schema) are greater than 0, read num_samples rows.
         num_parallel_workers (int, optional): number of workers to read the data
             (default=None, number set in the config).
         shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
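These three cases line up with the new unit tests at the bottom of this diff. As a usage sketch (FILES and the schema files are placeholders):

import mindspore.dataset as ds

FILES = ["/path/to/data.tfrecord"]  # placeholder

# No num_samples, schema has no numRows: read the full dataset.
ds1 = ds.TFRecordDataset(FILES, "schema_no_rows.json")

# No num_samples, schema says numRows = 7: read 7 rows.
ds2 = ds.TFRecordDataset(FILES, "schema_7_rows.json")

# num_samples=8 overrides the schema's numRows: read 8 rows.
ds3 = ds.TFRecordDataset(FILES, "schema_7_rows.json", num_samples=8)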
@@ -2711,10 +2714,10 @@ class Schema:
     """
     def __init__(self, schema_file=None):
+        self.num_rows = None
         if schema_file is None:
             self.columns = []
             self.dataset_type = ''
-            self.num_rows = 0
         else:
             if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
                 raise ValueError("The file %s does not exist or permission denied!" % schema_file)
@@ -2859,6 +2862,9 @@ class Schema:
             raise RuntimeError("DatasetType field is missing.")
         if self.columns is None:
             raise RuntimeError("Columns are missing.")
+        if self.num_rows is not None:
+            if not isinstance(self.num_rows, int) or self.num_rows <= 0:
+                raise ValueError("numRows must be greater than 0")

     def __str__(self):
         return self.to_json()
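The rule in isolation: num_rows may stay unset (None, i.e. numRows omitted), but once set it must be a positive int. A standalone sketch of the check:

def check_num_rows(num_rows):
    # None means "numRows not given", which is now legal (read the full dataset).
    if num_rows is not None:
        if not isinstance(num_rows, int) or num_rows <= 0:
            raise ValueError("numRows must be greater than 0")

check_num_rows(None)  # ok: field omitted
check_num_rows(7)     # ok: positive row count
# check_num_rows(0)   # would raise ValueError: numRows must be greater than 0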

datasetSchemaNoRow.json (new file; used by test_case_tf_read_all_dataset below)
@@ -0,0 +1,45 @@
{
"datasetType": "TF",
"columns": {
"col_sint16": {
"type": "int16",
"rank": 1,
"shape": [1]
},
"col_sint32": {
"type": "int32",
"rank": 1,
"shape": [1]
},
"col_sint64": {
"type": "int64",
"rank": 1,
"shape": [1]
},
"col_float": {
"type": "float32",
"rank": 1,
"shape": [1]
},
"col_1d": {
"type": "int64",
"rank": 1,
"shape": [2]
},
"col_2d": {
"type": "int64",
"rank": 2,
"shape": [2, 2]
},
"col_3d": {
"type": "int64",
"rank": 3,
"shape": [2, 2, 2]
},
"col_binary": {
"type": "uint8",
"rank": 1,
"shape": [1]
}
}
}

datasetNoRowsSchema.json (new file; used by test_case_no_rows below)
@@ -0,0 +1,15 @@
{
"datasetType": "TF",
"columns": {
"image": {
"type": "uint8",
"rank": 1,
"t_impl": "cvmat"
},
"label" : {
"type": "uint64",
"rank": 1,
"t_impl": "flex"
}
}
}

@@ -37,3 +37,15 @@ def test_case_storage():
     filename = "storage_result.npz"
     save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+
+
+def test_case_no_rows():
+    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
+    dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    assert dataset.get_dataset_size() == 3
+    count = 0
+    for data in dataset.create_tuple_iterator():
+        count += 1
+    assert count == 3

@@ -37,6 +37,36 @@ def test_case_tf_shape():
     assert (len(output_shape[-1]) == 1)
+
+
+def test_case_tf_read_all_dataset():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 12
+    count = 0
+    for data in ds1.create_tuple_iterator():
+        count += 1
+    assert count == 12
+
+
+def test_case_num_samples():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
+    assert ds1.get_dataset_size() == 8
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 8
+
+
+def test_case_num_samples2():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 7
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 7
+
+
 def test_case_tf_shape_2():
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)
