diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
index eb4229be9c..a34634fd1d 100644
--- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
@@ -137,6 +137,10 @@ MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const j
   // Initialize num bytes
   *n_bytes = ColumnDataTypeSize[column_data_type];
   auto json_column_value = columns_json[column_name];
+  if (!json_column_value.is_string() && !json_column_value.is_number()) {
+    MS_LOG(ERROR) << "Conversion failed (" << json_column_value << ").";
+    return FAILED;
+  }
   switch (column_data_type) {
     case ColumnFloat32: {
       return GetFloat<float>(data_ptr, json_column_value, false);
@@ -152,7 +156,12 @@ MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const j
     }
     default: {
       // Convert string to c_str
-      std::string tmp_string = json_column_value;
+      std::string tmp_string;
+      if (json_column_value.is_string()) {
+        tmp_string = json_column_value.get<std::string>();
+      } else {
+        tmp_string = json_column_value.dump();
+      }
       *n_bytes = tmp_string.size();
       auto data = reinterpret_cast<const unsigned char *>(common::SafeCStr(tmp_string));
       *data_ptr = std::make_unique<unsigned char[]>(*n_bytes);
@@ -169,10 +178,6 @@ template <typename T>
 MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
                                 bool use_double) {
   std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
-  if (!json_column_value.is_string() && !json_column_value.is_number()) {
-    MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
-    return FAILED;
-  }
   if (json_column_value.is_number()) {
     array_data[0] = json_column_value;
   } else {
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 58ff041d35..a8a2bb785e 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -1147,7 +1147,8 @@ class Dataset:
             1. To save the samples in order, set dataset's shuffle to False and num_files to 1.
             2. Before calling the function, do not use batch operator, repeat operator or data augmentation operators
                with random attribute in map operator.
-            3. Mindrecord does not support DE_UINT64, multi-dimensional DE_UINT8(drop dimension) nor
+            3. Can not save number type tensor whose shape is dynamic.
+            4. Mindrecord does not support DE_UINT64, multi-dimensional DE_UINT8(drop dimension) nor
                multi-dimensional DE_STRING.
 
         Args:
diff --git a/tests/ut/python/dataset/test_minddataset_padded.py b/tests/ut/python/dataset/test_minddataset_padded.py
index b87e050054..97daba392a 100644
--- a/tests/ut/python/dataset/test_minddataset_padded.py
+++ b/tests/ut/python/dataset/test_minddataset_padded.py
@@ -136,6 +136,33 @@ def test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file):
     assert num_padded_iter == 5
     assert num_iter == 15
 
+def test_cv_minddataset_reader_basic_padded_samples_type_cast(add_and_remove_cv_file):
+    """tutorial for cv minderdataset."""
+    columns_list = ["label", "file_name", "data"]
+
+    data = get_data(CV_DIR_NAME)
+    padded_sample = data[0]
+    padded_sample['label'] = -1
+    padded_sample['file_name'] = 99999
+    num_readers = 4
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, padded_sample=padded_sample, num_padded=5)
+    assert data_set.get_dataset_size() == 15
+    num_iter = 0
+    num_padded_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        if item['label'] == -1:
+            num_padded_iter += 1
+            assert item['file_name'] == bytes(str(padded_sample['file_name']),
+                                              encoding='utf8')
+            assert item['label'] == padded_sample['label']
+            assert (item['data'] == np.array(list(padded_sample['data']))).all()
+        num_iter += 1
+    assert num_padded_iter == 5
+    assert num_iter == 15
+
 def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file):
     """tutorial for cv minddataset."""