From df9f4f41e80547aa13ed65780b5830029808c792 Mon Sep 17 00:00:00 2001
From: YangLuo
Date: Wed, 16 Dec 2020 14:54:27 +0800
Subject: [PATCH] fix get num classes of concat

---
 .../dataset/engine/datasetops/concat_op.cc        | 16 ++++++++++++++++
 .../dataset/engine/datasetops/concat_op.h         |  5 +++++
 .../ut/python/dataset/test_datasets_manifestop.py | 15 +++++++++++++++
 3 files changed, 36 insertions(+)

diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
index 9cba1119ab..a8e09514a3 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
@@ -196,5 +196,21 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
   return p->PreRunOnNode(shared_from_base(), modified);
 }
 
+// Gets the number of classes: the largest num_classes reported by any child (-1 when there is none).
+// A child whose query returns a non-OK Status stops the lookup and that Status is handed back to the caller.
+Status ConcatOp::GetNumClasses(int64_t *num_classes) {
+  int64_t max_num_classes = -1;
+  for (const auto &child : child_) {
+    // Preset to -1 so that a child which leaves the value untouched can never raise the maximum
+    int64_t tmp_num_classes = -1;
+    RETURN_IF_NOT_OK(child->GetNumClasses(&tmp_num_classes));
+    if (tmp_num_classes > max_num_classes) {
+      max_num_classes = tmp_num_classes;
+    }
+  }
+  *num_classes = max_num_classes;
+  return Status::OK();
+}
+
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
index bab503a133..ba2b1bedb7 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
@@ -111,6 +111,11 @@ class ConcatOp : public PipelineOp {
   /// \return Status of the node visit
   Status PreAccept(NodePass *p, bool *modified) override;
 
+  /// \brief Gets the number of classes, taken as the largest value reported by any child
+  /// \param[out] num_classes the number of classes (-1 when no child reports a larger value)
+  /// \return Status - The status code returned
+  Status GetNumClasses(int64_t *num_classes) override;
+
  private:
   Status Verify(int32_t id, const std::unique_ptr &buf);
 
diff --git a/tests/ut/python/dataset/test_datasets_manifestop.py b/tests/ut/python/dataset/test_datasets_manifestop.py
index f056d0edd4..3f891b8fb8 100644
--- a/tests/ut/python/dataset/test_datasets_manifestop.py
+++ b/tests/ut/python/dataset/test_datasets_manifestop.py
@@ -113,6 +113,20 @@ def test_manifest_dataset_multi_label_onehot():
         count = count + 1
 
 
+def test_manifest_dataset_get_num_class():
+    data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
+    assert data.num_classes() == 3
+
+    padded_samples = [{'image': np.zeros(1, np.uint8), 'label': np.array(1, np.int32)}]
+    padded_ds = ds.PaddedDataset(padded_samples)
+
+    data = data.repeat(2)
+    padded_ds = padded_ds.repeat(2)
+
+    data1 = data + padded_ds
+    assert data1.num_classes() == 3
+
+
 if __name__ == '__main__':
     test_manifest_dataset_train()
     test_manifest_dataset_eval()
@@ -120,3 +134,4 @@ if __name__ == '__main__':
     test_manifest_dataset_get_class_index()
     test_manifest_dataset_multi_label()
     test_manifest_dataset_multi_label_onehot()
+    test_manifest_dataset_get_num_class()