diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 148853a5bc..bc99c85fcb 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -655,6 +655,42 @@ Status Dataset::AddCacheOp(std::vector> *node_ops) { } return Status::OK(); } +int64_t Dataset::GetBatchSize() { + int64_t batch_size; + auto ds = shared_from_this(); + Status rc; + std::unique_ptr runtime_context = std::make_unique(); + rc = runtime_context->Init(); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed."; + return -1; + } + rc = tree_getters_->Init(ds); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed."; + return -1; + } + rc = tree_getters_->GetBatchSize(&batch_size); + return rc.IsError() ? -1 : batch_size; +} +int64_t Dataset::GetRepeatCount() { + int64_t repeat_count; + auto ds = shared_from_this(); + Status rc; + std::unique_ptr runtime_context = std::make_unique(); + rc = runtime_context->Init(); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed."; + return -1; + } + rc = tree_getters_->Init(ds); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed."; + return -1; + } + rc = tree_getters_->GetRepeatCount(&repeat_count); + return rc.IsError() ? 
0 : repeat_count; +} SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {} diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 3aa4bf3bc3..53910bdcce 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -430,4 +430,18 @@ Status TreeGetters::GetOutputShapes(std::vector *shapes) { } return Status::OK(); } + +Status TreeGetters::GetBatchSize(int64_t *batch_size) { + std::shared_ptr root = std::shared_ptr(tree_adapter_->GetRoot()); + CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); + *batch_size = root->GetTreeBatchSize(); + CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size."); + return Status::OK(); +} +Status TreeGetters::GetRepeatCount(int64_t *repeat_count) { + std::shared_ptr root = std::shared_ptr(tree_adapter_->GetRoot()); + CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); + *repeat_count = root->GetTreeRepeatCount(); + return Status::OK(); +} } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index e4376d14b0..c7616f0034 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -162,6 +162,8 @@ class TreeGetters : public TreeConsumer { Status GetDatasetSize(int64_t *size); Status GetOutputTypes(std::vector *types); Status GetOutputShapes(std::vector *shapes); + Status GetBatchSize(int64_t *batch_size); + Status GetRepeatCount(int64_t *repeat_count); bool isInitialized(); std::string Name() override { return "TreeGetters"; } Status GetRow(TensorRow *r); diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index 28b2e88932..7b10f9bfc0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -555,6 +555,14 @@ Status BatchOp::GetDatasetSize(int64_t *dataset_size) { dataset_size_ = num_rows; return Status::OK(); } +int64_t BatchOp::GetTreeBatchSize() { +#ifdef ENABLE_PYTHON + if (batch_size_func_) { + return -1; + } +#endif + return start_batch_size_; +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h index 1e6a66efa8..9d1cd9539c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h @@ -224,6 +224,8 @@ class BatchOp : public ParallelOp { /// \return Status of the function Status GetDatasetSize(int64_t *dataset_size) override; + int64_t GetTreeBatchSize() override; + protected: Status ComputeColMap() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index e656f1724f..8181e17896 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -455,5 +455,17 @@ void DatasetOp::UpdateRepeatAndEpochCounter() { if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++; MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_; } +int64_t DatasetOp::GetTreeBatchSize() { + if (!child_.empty()) { + return child_[0]->GetTreeBatchSize(); + } + return 1; +} +int64_t DatasetOp::GetTreeRepeatCount() { + if (!child_.empty()) { + return child_[0]->GetTreeRepeatCount(); + } + 
return 1; +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index 75b50c3ba6..c8204e8426 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -183,6 +183,14 @@ class DatasetOp : public std::enable_shared_from_this { /// \return Status - The status code return virtual Status GetDatasetSize(int64_t *dataset_size); + /// \brief Gets the batch size + /// \return int64_t - The batch size + virtual int64_t GetTreeBatchSize(); + + /// \brief Gets the repeat count + /// \return int64_t - The repeat count + virtual int64_t GetTreeRepeatCount(); + /// \brief Performs handling for when an eoe message is received. /// The base class implementation simply flows the eoe message to output. Derived classes /// may override if they need to perform special eoe handling. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc index b8c94a01c2..1eb9eecc8e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc @@ -119,5 +119,6 @@ Status EpochCtrlOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call the pre-visitation return p->RunOnNode(shared_from_base(), modified); } +int64_t EpochCtrlOp::GetTreeRepeatCount() { return child_[0]->GetTreeRepeatCount(); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h index c494208116..36ce8f2039 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h @@ -76,6 +76,8 @@ class EpochCtrlOp : public RepeatOp { /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit Status Accept(NodePass *p, bool *modified) override; + + int64_t GetTreeRepeatCount() override; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc index e8fdf1d9a0..25830a6dfb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc @@ -207,5 +207,6 @@ Status RepeatOp::GetDatasetSize(int64_t *dataset_size) { dataset_size_ = num_rows; return Status::OK(); } +int64_t RepeatOp::GetTreeRepeatCount() { return num_repeats_; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h index 
50dab4bac0..74339d725a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h @@ -138,6 +138,8 @@ class RepeatOp : public PipelineOp { /// \return Status of the function Status GetDatasetSize(int64_t *dataset_size) override; + int64_t GetTreeRepeatCount() override; + // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes // \param[in] eoe_op The input leaf/eoe operator to add to the list void AddToEoeList(std::shared_ptr eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index f3704b54c8..bf80a876d4 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -588,17 +588,25 @@ class Dataset : public std::enable_shared_from_this { } /// \brief Gets the dataset size - /// \return status code + /// \return int64_t int64_t GetDatasetSize(); /// \brief Gets the output type - /// \return status code + /// \return vector of DataType std::vector GetOutputTypes(); /// \brief Gets the output shape - /// \return status code + /// \return vector of TensorShapes std::vector GetOutputShapes(); + /// \brief Gets the batch size + /// \return int64_t + int64_t GetBatchSize(); + + /// \brief Gets the repeat count + /// \return int64_t + int64_t GetRepeatCount(); + /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator /// \return Shared pointer to the original object @@ -668,16 +676,18 @@ class Dataset : public std::enable_shared_from_this { /// 0Stop(); } +TEST_F(MindDataTestPipeline, TestGetRepeatCount) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetRepeatCount."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = 
ImageFolder(folder_path, true); + EXPECT_NE(ds, nullptr); + EXPECT_EQ(ds->GetRepeatCount(), 1); + ds = ds->Repeat(4); + EXPECT_NE(ds, nullptr); + EXPECT_EQ(ds->GetRepeatCount(), 4); + ds = ds->Repeat(3); + EXPECT_NE(ds, nullptr); + EXPECT_EQ(ds->GetRepeatCount(), 3); +} + +TEST_F(MindDataTestPipeline, TestGetBatchSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetRepeatCount."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true)->Project({"label"}); + EXPECT_NE(ds, nullptr); + EXPECT_EQ(ds->GetBatchSize(), 1); + ds = ds->Batch(2); + EXPECT_NE(ds, nullptr); + EXPECT_EQ(ds->GetBatchSize(), 2); + ds = ds->Batch(3); + EXPECT_NE(ds, nullptr); + EXPECT_EQ(ds->GetBatchSize(), 3); +} TEST_F(MindDataTestPipeline, TestCelebAGetDatasetSize) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebAGetDatasetSize.";