Fixed GetDatasetSize for TextFile

pull/8713/head
Mahdi 4 years ago
parent fedb225a96
commit 449e1526dc

@ -529,7 +529,7 @@ Status TextFileOp::GetDatasetSize(int64_t *dataset_size) {
int64_t num_rows, sample_size;
sample_size = total_rows_;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = total_rows_;
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();

@ -105,6 +105,10 @@ TEST_F(MindDataTestPipeline, TestTextFileGetters) {
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
ds = TextFile({tf_file1}, 0);
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 3);
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);

Loading…
Cancel
Save