diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 935150c361..797383f131 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -529,7 +529,7 @@ Status TextFileOp::GetDatasetSize(int64_t *dataset_size) { int64_t num_rows, sample_size; sample_size = total_rows_; if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - num_rows = total_rows_; + num_rows = num_rows_per_shard_; *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc index 73ebc7bdae..1dbc1f1b13 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc @@ -105,6 +105,10 @@ TEST_F(MindDataTestPipeline, TestTextFileGetters) { EXPECT_EQ(ds->GetDatasetSize(), 2); EXPECT_EQ(ds->GetColumnNames(), column_names); + ds = TextFile({tf_file1}, 0); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 3); // Restore configuration GlobalContext::config_manager()->set_seed(original_seed); GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);