add floatxx test case

pull/2015/head
jonyguo 5 years ago
parent 9dfb1011fe
commit 1de7271afc

@@ -335,15 +335,15 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
     int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
     if (field_type == "INTEGER") {
-      if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) {
+      if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
         MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
-                      << ", field value: " << std::stoi(field_value);
+                      << ", field value: " << std::stoll(field_value);
         return FAILED;
       }
     } else if (field_type == "NUMERIC") {
-      if (sqlite3_bind_double(stmt, index, std::stod(field_value)) != SQLITE_OK) {
+      if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
         MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
-                      << ", field value: " << std::stoi(field_value);
+                      << ", field value: " << std::stold(field_value);
         return FAILED;
       }
     } else if (field_type == "NULL") {
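
The hunk above widens the INTEGER parameter binding in the index generator from 32-bit (sqlite3_bind_int with std::stoi) to 64-bit (sqlite3_bind_int64 with std::stoll), and parses NUMERIC values with std::stold before binding. Without it, int64 scalars like the ones written by the new test are outside the range std::stoi can convert. The following is a minimal Python sketch of the 64-bit round trip this enables, using the standard sqlite3 module rather than the C++ path itself:

import sqlite3

INT32_MAX = 2**31 - 1
value = 947654321123                  # an int64 scalar taken from the new test data
assert value > INT32_MAX              # too large for a 32-bit bind / std::stoi

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t (v INTEGER)")
con.execute("INSERT INTO t (v) VALUES (?)", (value,))  # Python binds integers as 64-bit
assert con.execute("SELECT v FROM t").fetchone()[0] == value
con.close()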

@@ -17,6 +17,7 @@ This is the test module for mindrecord
 """
 import collections
 import json
+import math
 import os
 import re
 import string
@@ -1605,3 +1606,149 @@ def test_write_with_multi_array_and_MindDataset():
    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))

def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset():
    mindrecord_file_name = "test.mindrecord"
    data = [{"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12345,
             "float64": 1987654321.123456785,
             "int32_array": np.array([1, 2, 3, 4, 5], dtype=np.int32),
             "int64_array": np.array([48, 49, 50, 51, 123414314, 87], dtype=np.int64),
             "int32": 3456,
             "int64": 947654321123},
            {"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12445,
             "float64": 1987654321.123456786,
             "int32_array": np.array([11, 21, 31, 41, 51], dtype=np.int32),
             "int64_array": np.array([481, 491, 501, 511, 1234143141, 871], dtype=np.int64),
             "int32": 3466,
             "int64": 957654321123},
            {"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12545,
             "float64": 1987654321.123456787,
             "int32_array": np.array([12, 22, 32, 42, 52], dtype=np.int32),
             "int64_array": np.array([482, 492, 502, 512, 1234143142, 872], dtype=np.int64),
             "int32": 3476,
             "int64": 967654321123},
            {"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12645,
             "float64": 1987654321.123456788,
             "int32_array": np.array([13, 23, 33, 43, 53], dtype=np.int32),
             "int64_array": np.array([483, 493, 503, 513, 1234143143, 873], dtype=np.int64),
             "int32": 3486,
             "int64": 977654321123},
            {"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12745,
             "float64": 1987654321.123456789,
             "int32_array": np.array([14, 24, 34, 44, 54], dtype=np.int32),
             "int64_array": np.array([484, 494, 504, 514, 1234143144, 874], dtype=np.int64),
             "int32": 3496,
             "int64": 987654321123},
            ]
    writer = FileWriter(mindrecord_file_name)
    schema = {"float32_array": {"type": "float32", "shape": [-1]},
              "float64_array": {"type": "float64", "shape": [-1]},
              "float32": {"type": "float32"},
              "float64": {"type": "float64"},
              "int32_array": {"type": "int32", "shape": [-1]},
              "int64_array": {"type": "int64", "shape": [-1]},
              "int32": {"type": "int32"},
              "int64": {"type": "int64"}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()
    # build the expected values; the fields are kept as written, no list conversion is needed here
    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['float32_array'] = item["float32_array"]
        new_data['float64_array'] = item["float64_array"]
        new_data['float32'] = item["float32"]
        new_data['float64'] = item["float64"]
        new_data['int32_array'] = item["int32_array"]
        new_data['int64_array'] = item["int64_array"]
        new_data['int32'] = item["int32"]
        new_data['int64'] = item["int64"]
        data_value_to_list.append(new_data)
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 8
        for field in item:
            if isinstance(item[field], np.ndarray):
                if item[field].dtype == np.float32:
                    assert (item[field] ==
                            np.array(data_value_to_list[num_iter][field], np.float32)).all()
                else:
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 5
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["float32", "int32"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 2
        for field in item:
            if isinstance(item[field], np.ndarray):
                if item[field].dtype == np.float32:
                    assert (item[field] ==
                            np.array(data_value_to_list[num_iter][field], np.float32)).all()
                else:
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 5
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["float64", "int64"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 2
        for field in item:
            if isinstance(item[field], np.ndarray):
                if item[field].dtype == np.float32:
                    assert (item[field] ==
                            np.array(data_value_to_list[num_iter][field], np.float32)).all()
                elif item[field].dtype == np.float64:
                    assert math.isclose(item[field],
                                        np.array(data_value_to_list[num_iter][field], np.float64), rel_tol=1e-14)
                else:
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 5
    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
