diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index d16a7c46fa..b1d80db9dd 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -21,6 +21,7 @@ #include "minddata/dataset/kernels/image/crop_op.h" #include "minddata/dataset/kernels/image/cut_out_op.h" #include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/hwc_to_chw_op.h" #include "minddata/dataset/kernels/image/mixup_batch_op.h" #include "minddata/dataset/kernels/image/normalize_op.h" #include "minddata/dataset/kernels/data/one_hot_op.h" @@ -84,6 +85,16 @@ std::shared_ptr Decode(bool rgb) { return op; } +// Function to create HwcToChwOperation. +std::shared_ptr HWC2CHW() { + auto op = std::make_shared(); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + // Function to create MixUpBatchOperation. std::shared_ptr MixUpBatch(float alpha) { auto op = std::make_shared(alpha); @@ -311,6 +322,11 @@ bool DecodeOperation::ValidateParams() { return true; } std::shared_ptr DecodeOperation::Build() { return std::make_shared(rgb_); } +// HwcToChwOperation +bool HwcToChwOperation::ValidateParams() { return true; } + +std::shared_ptr HwcToChwOperation::Build() { return std::make_shared(); } + // MixUpOperation MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {} diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h index 967fe1b57b..b33516486b 100644 --- a/mindspore/ccsrc/minddata/dataset/include/transforms.h +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -51,6 +51,7 @@ class CenterCropOperation; class CropOperation; class CutOutOperation; class DecodeOperation; +class HwcToChwOperation; class MixUpBatchOperation; class NormalizeOperation; class OneHotOperation; @@ -93,6 +94,11 @@ std::shared_ptr CutOut(int32_t length, int32_t num_patches = 1) /// \return Shared pointer to the current TensorOperation. std::shared_ptr Decode(bool rgb = true); +/// \brief Function to create a HwcToChw TensorOperation. +/// \notes Transpose the input image; shape (H, W, C) to shape (C, H, W). +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr HWC2CHW(); + /// \brief Function to create a MixUpBatch TensorOperation. /// \notes Apply MixUp transformation on an input batch of images and labels. The labels must be in one-hot format and /// Batch must be called before calling this function. @@ -291,6 +297,15 @@ class DecodeOperation : public TensorOperation { bool rgb_; }; +class HwcToChwOperation : public TensorOperation { + public: + ~HwcToChwOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; +}; + class MixUpBatchOperation : public TensorOperation { public: explicit MixUpBatchOperation(float alpha = 1); diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index c98b56c4cc..06ca76d70f 100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -463,6 +463,55 @@ TEST_F(MindDataTestPipeline, TestDecode) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestHwcToChw) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr channel_swap = vision::HWC2CHW(); + EXPECT_NE(channel_swap, nullptr); + + // Create a Map operation on ds + ds = ds->Map({channel_swap}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + // check if the image is in NCHW + EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1] + && 2268 == image->shape()[2] && 4032 == image->shape()[3], true); + iter->GetNextRow(&row); + } + EXPECT_EQ(i, 20); + + // Manually terminate the pipeline + iter->Stop(); +} + TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/";