From 1de7271afc0859283de2c04de01efd40be2caf1e Mon Sep 17 00:00:00 2001
From: jonyguo
Date: Thu, 11 Jun 2020 18:11:09 +0800
Subject: [PATCH] add floatxx test case

---
Review notes (text between the three-dash line and the first "diff --git"
line is commentary, not part of the change):
 * The line structure of this patch was restored from a whitespace-collapsed
   copy. Indentation and the position of blank context lines are a
   best-effort reconstruction; apply with "git apply --ignore-whitespace" if
   the context does not match.
 * test_minddataset.py: the new test now removes its output files in a
   "finally" block, the three copy-pasted read-back loops are one nested
   helper, and the no-op copy of "data" is gone. Each read-back pass keeps
   the comparison it had before (exact equality, or math.isclose with
   rel_tol=1e-14 for float64 in the last pass).
 * shard_index_generator.cc: hunk content unchanged.
 * The commit id above and the post-image blob id of test_minddataset.py
   (0ef43a5bc7) describe the original revision of this patch; regenerate
   the patch with git format-patch after applying to refresh them.

 .../mindrecord/io/shard_index_generator.cc    |   8 +-
 tests/ut/python/dataset/test_minddataset.py   | 108 ++++++++++++++++++
 2 files changed, 112 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
index 905968e3a2..f72db49e20 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
@@ -335,15 +335,15 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
 
       int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
       if (field_type == "INTEGER") {
-        if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) {
+        if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
           MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
-                        << ", field value: " << std::stoi(field_value);
+                        << ", field value: " << std::stoll(field_value);
           return FAILED;
         }
       } else if (field_type == "NUMERIC") {
-        if (sqlite3_bind_double(stmt, index, std::stod(field_value)) != SQLITE_OK) {
+        if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
           MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
-                        << ", field value: " << std::stoi(field_value);
+                        << ", field value: " << std::stold(field_value);
           return FAILED;
         }
       } else if (field_type == "NULL") {
diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py
index 986fc6b665..0ef43a5bc7 100644
--- a/tests/ut/python/dataset/test_minddataset.py
+++ b/tests/ut/python/dataset/test_minddataset.py
@@ -17,6 +17,7 @@ This is the test module for mindrecord
 """
 import collections
 import json
+import math
 import os
 import re
 import string
@@ -1605,3 +1606,110 @@ def test_write_with_multi_array_and_MindDataset():
 
     os.remove("{}".format(mindrecord_file_name))
     os.remove("{}.db".format(mindrecord_file_name))
+
+def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset():
+    """Write float32/float64/int32/int64 scalars and arrays, then read them back."""
+    mindrecord_file_name = "test.mindrecord"
+    data = [{"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
+             "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
+                                        123414314.2141243, 87.1212122], dtype=np.float64),
+             "float32": 3456.12345,
+             "float64": 1987654321.123456785,
+             "int32_array": np.array([1, 2, 3, 4, 5], dtype=np.int32),
+             "int64_array": np.array([48, 49, 50, 51, 123414314, 87], dtype=np.int64),
+             "int32": 3456,
+             "int64": 947654321123},
+            {"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
+             "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
+                                        123414314.2141243, 87.1212122], dtype=np.float64),
+             "float32": 3456.12445,
+             "float64": 1987654321.123456786,
+             "int32_array": np.array([11, 21, 31, 41, 51], dtype=np.int32),
+             "int64_array": np.array([481, 491, 501, 511, 1234143141, 871], dtype=np.int64),
+             "int32": 3466,
+             "int64": 957654321123},
+            {"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
+             "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
+                                        123414314.2141243, 87.1212122], dtype=np.float64),
+             "float32": 3456.12545,
+             "float64": 1987654321.123456787,
+             "int32_array": np.array([12, 22, 32, 42, 52], dtype=np.int32),
+             "int64_array": np.array([482, 492, 502, 512, 1234143142, 872], dtype=np.int64),
+             "int32": 3476,
+             "int64": 967654321123},
+            {"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
+             "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
+                                        123414314.2141243, 87.1212122], dtype=np.float64),
+             "float32": 3456.12645,
+             "float64": 1987654321.123456788,
+             "int32_array": np.array([13, 23, 33, 43, 53], dtype=np.int32),
+             "int64_array": np.array([483, 493, 503, 513, 1234143143, 873], dtype=np.int64),
+             "int32": 3486,
+             "int64": 977654321123},
+            {"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
+             "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
+                                        123414314.2141243, 87.1212122], dtype=np.float64),
+             "float32": 3456.12745,
+             "float64": 1987654321.123456789,
+             "int32_array": np.array([14, 24, 34, 44, 54], dtype=np.int32),
+             "int64_array": np.array([484, 494, 504, 514, 1234143144, 874], dtype=np.int64),
+             "int32": 3496,
+             "int64": 987654321123},
+            ]
+    schema = {"float32_array": {"type": "float32", "shape": [-1]},
+              "float64_array": {"type": "float64", "shape": [-1]},
+              "float32": {"type": "float32"},
+              "float64": {"type": "float64"},
+              "int32_array": {"type": "int32", "shape": [-1]},
+              "int64_array": {"type": "int64", "shape": [-1]},
+              "int32": {"type": "int32"},
+              "int64": {"type": "int64"}}
+
+    def check_read_back(columns_list=None, float64_rel_tol=None):
+        """Read the file back and compare every returned column with `data`.
+
+        float32 columns are compared after casting the expected value to
+        float32. float64 columns use math.isclose when float64_rel_tol is
+        given (scalar columns only) and exact equality otherwise.
+        """
+        reader_args = {"dataset_file": mindrecord_file_name,
+                       "num_parallel_workers": 2,
+                       "shuffle": False}
+        if columns_list is not None:
+            reader_args["columns_list"] = columns_list
+        data_set = ds.MindDataset(**reader_args)
+        assert data_set.get_dataset_size() == len(data)
+        expected_columns = len(schema) if columns_list is None else len(columns_list)
+        num_iter = 0
+        for item in data_set.create_dict_iterator():
+            assert len(item) == expected_columns
+            for field, actual in item.items():
+                expected = data[num_iter][field]
+                if not isinstance(actual, np.ndarray):
+                    assert actual == expected
+                elif actual.dtype == np.float32:
+                    assert (actual == np.array(expected, np.float32)).all()
+                elif actual.dtype == np.float64 and float64_rel_tol is not None:
+                    assert math.isclose(actual, np.array(expected, np.float64),
+                                        rel_tol=float64_rel_tol)
+                else:
+                    assert (actual == expected).all()
+            num_iter += 1
+        assert num_iter == len(data)
+
+    try:
+        writer = FileWriter(mindrecord_file_name)
+        writer.add_schema(schema, "data is so cool")
+        writer.write_raw_data(data)
+        writer.commit()
+
+        check_read_back()
+        check_read_back(columns_list=["float32", "int32"])
+        # Only this pass compares float64 with a tolerance instead of exact equality.
+        check_read_back(columns_list=["float64", "int64"], float64_rel_tol=1e-14)
+    finally:
+        # Clean up even when an assertion above fails, so a failed run does
+        # not leave test.mindrecord / test.mindrecord.db on disk.
+        for file_name in (mindrecord_file_name, "{}.db".format(mindrecord_file_name)):
+            if os.path.exists(file_name):
+                os.remove(file_name)