diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index e6009d7388..1dfbd9a3ad 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -27,18 +27,14 @@ #include "proto/example.pb.h" #include "./securec.h" -#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/global_context.h" -#include "minddata/dataset/engine/connector.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/db_connector.h" #include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/jagged_connector.h" #include "minddata/dataset/engine/opt/pass.h" -#include "minddata/dataset/util/path.h" -#include "minddata/dataset/util/queue.h" #include "minddata/dataset/util/random.h" #include "minddata/dataset/util/status.h" #include "minddata/dataset/util/task_manager.h" @@ -387,14 +383,14 @@ Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) { return Status::OK(); } -bool TFReaderOp::NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, +bool TFReaderOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, const int64_t &pre_count) { *start_offset = 0; *end_offset = 0; bool push = false; int64_t start_index = device_id_ * num_rows_per_shard_; if (device_id_ + 1 < 0) { - MS_LOG(ERROR) << "Device id is invalid"; + MS_LOG(ERROR) << "Device id is invalid."; return false; } int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; @@ -448,7 +444,7 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector &i_keys) { } else { // Do an index lookup using that key to get the filename. std::string file_name = (*filename_index_)[*it]; - if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { + if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) { auto ioBlock = std::make_unique(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone); RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset; @@ -496,7 +492,7 @@ Status TFReaderOp::FillIOBlockNoShuffle() { } } else { std::string file_name = it.value(); - if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { + if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) { auto ioBlock = std::make_unique(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone); RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); queue_index = (queue_index + 1) % num_workers_; @@ -711,7 +707,7 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr *tensor_table // reinitializes itself so that it can be executed again, as if it was just created. Status TFReaderOp::Reset() { MS_LOG(DEBUG) << Name() << " performing a self-reset."; - // start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true + // start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true load_jagged_connector_ = true; { @@ -767,6 +763,14 @@ Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataeng new_pad_size *= cur_shape[i]; } pad_size = new_pad_size; + } else { + if (cur_shape.known() && cur_shape.NumOfElements() != max_size) { + std::string err_msg = "Shape in schema's column '" + current_col.name() + "' is incorrect." + + "\nshape received: " + cur_shape.ToString() + + "\ntotal elements in shape received: " + std::to_string(cur_shape.NumOfElements()) + + "\nexpected total elements in shape: " + std::to_string(max_size); + RETURN_STATUS_UNEXPECTED(err_msg); + } } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h index 748979f50e..217ed4787b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h @@ -387,7 +387,7 @@ class TFReaderOp : public ParallelOp { // @param end_file - If file contains the end sample of data. // @param pre_count - Total rows of previous files. // @return Status - the error code returned. - bool NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, const int64_t &pre_count); // Caculate number of rows in each shard. diff --git a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc index 04b7c89716..c65590aada 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc @@ -491,3 +491,17 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) { std::shared_ptr iter = ds->CreateIterator(); EXPECT_EQ(iter, nullptr); } + +TEST_F(MindDataTestPipeline, TestIncorrectTFSchemaObject) { + std::string path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; + std::shared_ptr schema = api::Schema(); + schema->add_column("image", "uint8", {1}); + schema->add_column("label", "int64", {1}); + std::shared_ptr ds = api::TFRecord({path}, schema); + EXPECT_NE(ds, nullptr); + auto itr = ds->CreateIterator(); + EXPECT_NE(itr, nullptr); + TensorMap mp; + // this will fail due to the incorrect schema used + EXPECT_FALSE(itr->GetNextRow(&mp)); +} diff --git a/tests/ut/python/dataset/test_datasets_tfrecord.py b/tests/ut/python/dataset/test_datasets_tfrecord.py index 35e13a859f..fec82b14de 100644 --- a/tests/ut/python/dataset/test_datasets_tfrecord.py +++ b/tests/ut/python/dataset/test_datasets_tfrecord.py @@ -294,6 +294,24 @@ def test_tfrecord_invalid_files(): assert nonexistent_file in str(info.value) +def test_tf_wrong_schema(): + logger.info("test_tf_wrong_schema") + files = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data"] + schema = ds.Schema() + schema.add_column('image', de_type=mstype.uint8, shape=[1]) + schema.add_column('label', de_type=mstype.int64, shape=[1]) + data1 = ds.TFRecordDataset(files, schema, shuffle=False) + exception_occurred = False + try: + for _ in data1: + pass + except RuntimeError as e: + exception_occurred = True + assert "Shape in schema's column 'image' is incorrect" in str(e) + + assert exception_occurred, "test_tf_wrong_schema failed." + + if __name__ == '__main__': test_tfrecord_shape() test_tfrecord_read_all_dataset() @@ -312,3 +330,4 @@ if __name__ == '__main__': test_tfrecord_no_schema_columns_list() test_tfrecord_schema_columns_list() test_tfrecord_invalid_files() + test_tf_wrong_schema()