diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc index affc3600df..860dd3b3ba 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc @@ -110,5 +110,12 @@ PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) { .export_values(); })); +PYBIND_REGISTER(ImageBatchFormat, 0, ([](const py::module *m) { + (void)py::enum_<ImageBatchFormat>(*m, "ImageBatchFormat", py::arithmetic()) + .value("DE_IMAGE_BATCH_FORMAT_NHWC", ImageBatchFormat::kNHWC) + .value("DE_IMAGE_BATCH_FORMAT_NCHW", ImageBatchFormat::kNCHW) + .export_values(); + })); + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc index eafdc8a1a6..68f8a9d64d 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc @@ -22,6 +22,7 @@ #include "minddata/dataset/kernels/image/auto_contrast_op.h" #include "minddata/dataset/kernels/image/bounding_box_augment_op.h" #include "minddata/dataset/kernels/image/center_crop_op.h" +#include "minddata/dataset/kernels/image/cutmix_batch_op.h" #include "minddata/dataset/kernels/image/cut_out_op.h" #include "minddata/dataset/kernels/image/decode_op.h" #include "minddata/dataset/kernels/image/equalize_op.h" @@ -105,6 +106,13 @@ PYBIND_REGISTER(MixUpBatchOp, 1, ([](const py::module *m) { .def(py::init<float>(), py::arg("alpha")); })); +PYBIND_REGISTER(CutMixBatchOp, 1, ([](const py::module *m) { + (void)py::class_<CutMixBatchOp, TensorOp, std::shared_ptr<CutMixBatchOp>>( + *m, "CutMixBatchOp", "Tensor operation to cutmix a batch of images") + .def(py::init<ImageBatchFormat, float, float>(), py::arg("image_batch_format"), py::arg("alpha"), + py::arg("prob")); + })); + PYBIND_REGISTER(ResizeOp, 1, ([](const py::module *m) { (void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>( *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode") diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index 4eba7ac05b..bd027d1581 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -19,6 +19,7 @@ #include "minddata/dataset/kernels/image/center_crop_op.h" #include "minddata/dataset/kernels/image/crop_op.h" +#include "minddata/dataset/kernels/image/cutmix_batch_op.h" #include "minddata/dataset/kernels/image/cut_out_op.h" #include "minddata/dataset/kernels/image/decode_op.h" #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" @@ -70,6 +71,16 @@ std::shared_ptr<TensorOperation> Crop(std::vector coordinates, std::vecto return op; } +// Function to create CutMixBatchOperation. +std::shared_ptr<TensorOperation> CutMixBatch(ImageBatchFormat image_batch_format, float alpha, float prob) { + auto op = std::make_shared<CutMixBatchOperation>(image_batch_format, alpha, prob); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + // Function to create CutOutOp.
std::shared_ptr<TensorOperation> CutOut(int32_t length, int32_t num_patches) { auto op = std::make_shared<CutOutOperation>(length, num_patches); @@ -355,6 +366,27 @@ std::shared_ptr<TensorOp> CropOperation::Build() { return tensor_op; } +// CutMixBatchOperation +CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha, float prob) + : image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {} + +bool CutMixBatchOperation::ValidateParams() { + if (alpha_ < 0) { + MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative."; + return false; + } + if (prob_ < 0 || prob_ > 1) { + MS_LOG(ERROR) << "CutMixBatch: Probability has to be between 0 and 1."; + return false; + } + return true; +} + +std::shared_ptr<TensorOp> CutMixBatchOperation::Build() { + std::shared_ptr<TensorOp> tensor_op = std::make_shared<CutMixBatchOp>(image_batch_format_, alpha_, prob_); + return tensor_op; +} + // CutOutOperation CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {} diff --git a/mindspore/ccsrc/minddata/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h index 61c9711d4b..55a67c8946 100644 --- a/mindspore/ccsrc/minddata/dataset/core/constants.h +++ b/mindspore/ccsrc/minddata/dataset/core/constants.h @@ -41,6 +41,12 @@ enum class ShuffleMode { kFalse = 0, kFiles = 1, kGlobal = 2 }; // Possible values for Border types enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 }; +// Possible values for Image format types in a batch +enum class ImageBatchFormat { kNHWC = 0, kNCHW = 1 }; + +// Possible values for Image format types +enum class ImageFormat { HWC = 0, CHW = 1, HW = 2 }; + // Possible interpolation modes enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3 }; diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h index 7b13b43d19..567a344308 100644 --- a/mindspore/ccsrc/minddata/dataset/include/transforms.h +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -49,6 +49,7 @@ namespace vision { // Transform Op classes (in alphabetical order) class CenterCropOperation; class CropOperation; +class CutMixBatchOperation; class CutOutOperation; class DecodeOperation; class HwcToChwOperation; @@ -86,6 +87,16 @@ std::shared_ptr<TensorOperation> CenterCrop(std::vector size); /// \return Shared pointer to the current TensorOp std::shared_ptr<TensorOperation> Crop(std::vector coordinates, std::vector size); +/// \brief Function to apply CutMix on a batch of images +/// \notes Masks a random section of each image with the corresponding part of another randomly selected image in +/// that batch +/// \param[in] image_batch_format The format of the batch +/// \param[in] alpha The hyperparameter of beta distribution (default = 1.0) +/// \param[in] prob The probability by which CutMix is applied to each image (default = 1.0) +/// \return Shared pointer to the current TensorOp +std::shared_ptr<TensorOperation> CutMixBatch(ImageBatchFormat image_batch_format, float alpha = 1.0, + float prob = 1.0); + /// \brief Function to create a CutOut TensorOp /// \notes Randomly cut (mask) out a given number of square patches from the input image /// \param[in] length Integer representing the side length of each square patch @@ -305,6 +316,22 @@ class CropOperation : public TensorOperation { std::vector size_; }; +class CutMixBatchOperation : public TensorOperation { + public: + explicit CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha = 1.0, float prob = 1.0); + +
~CutMixBatchOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + float alpha_; + float prob_; + ImageBatchFormat image_batch_format_; +}; + class CutOutOperation : public TensorOperation { public: explicit CutOutOperation(int32_t length, int32_t num_patches = 1); @@ -318,6 +345,7 @@ class CutOutOperation : public TensorOperation { private: int32_t length_; int32_t num_patches_; + ImageBatchFormat image_batch_format_; }; class DecodeOperation : public TensorOperation { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc index 29fd5ada8b..662d199ea4 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc @@ -655,7 +655,7 @@ Status BatchTensorToCVTensorVector(const std::shared_ptr &input, TensorShape remaining({-1}); std::vector index(tensor_shape.size(), 0); if (tensor_shape.size() <= 1) { - RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack"); + RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack."); } TensorShape element_shape(std::vector(tensor_shape.begin() + 1, tensor_shape.end())); @@ -664,15 +664,48 @@ Status BatchTensorToCVTensorVector(const std::shared_ptr &input, std::shared_ptr out; RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining)); - RETURN_IF_NOT_OK(input->CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out)); std::shared_ptr cv_out = CVTensor::AsCVTensor(std::move(out)); if (!cv_out->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor."); } output->push_back(cv_out); } return Status::OK(); } +Status BatchTensorToTensorVector(const std::shared_ptr &input, std::vector> *output) { + std::vector tensor_shape = input->shape().AsVector(); + TensorShape remaining({-1}); + std::vector index(tensor_shape.size(), 0); + if (tensor_shape.size() <= 1) { + RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack."); + } + TensorShape element_shape(std::vector(tensor_shape.begin() + 1, tensor_shape.end())); + + for (; index[0] < tensor_shape[0]; index[0]++) { + uchar *start_addr_of_index = nullptr; + std::shared_ptr out; + + RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out)); + output->push_back(out); + } + return Status::OK(); +} + +Status TensorVectorToBatchTensor(const std::vector> &input, std::shared_ptr *output) { + if (input.empty()) { + RETURN_STATUS_UNEXPECTED("TensorVectorToBatchTensor: Received an empty vector."); + } + std::vector tensor_shape = input.front()->shape().AsVector(); + tensor_shape.insert(tensor_shape.begin(), input.size()); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(tensor_shape), input.at(0)->type(), output)); + for (int i = 0; i < input.size(); i++) { + RETURN_IF_NOT_OK((*output)->InsertTensor({i}, input[i])); + } + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h index 4fba6aef95..fd66d4def3 100644 --- 
a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h @@ -158,11 +158,24 @@ Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr append); /// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional CVTensors -/// @param input[in] input tensor -/// @param output[out] output tensor -/// @return Status ok/error +/// \param input[in] input tensor +/// \param output[out] output vector of CVTensors +/// \return Status ok/error Status BatchTensorToCVTensorVector(const std::shared_ptr &input, std::vector> *output); + +/// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional Tensors +/// \param input[in] input tensor +/// \param output[out] output vector of tensors +/// \return Status ok/error +Status BatchTensorToTensorVector(const std::shared_ptr &input, std::vector> *output); + +/// Convert a vector of (n-1)-dimensional Tensors to an n-dimensional Tensor +/// \param input[in] input vector of tensors +/// \param output[out] output tensor +/// \return Status ok/error +Status TensorVectorToBatchTensor(const std::vector> &input, std::shared_ptr *output); + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index 20733af9c8..52cf4096d8 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -7,6 +7,7 @@ add_library(kernels-image OBJECT center_crop_op.cc crop_op.cc cut_out_op.cc + cutmix_batch_op.cc decode_op.cc equalize_op.cc hwc_to_chw_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc new file mode 100644 index 0000000000..a03d5936b4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc @@ -0,0 +1,166 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <algorithm> +#include <random> +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/cutmix_batch_op.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +CutMixBatchOp::CutMixBatchOp(ImageBatchFormat image_batch_format, float alpha, float prob) + : image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) { + rnd_.seed(GetSeed()); +} + +void CutMixBatchOp::GetCropBox(int height, int width, float lam, int *x, int *y, int *crop_width, int *crop_height) { + float cut_ratio = 1 - lam; + int cut_w = static_cast<int>(width * cut_ratio); + int cut_h = static_cast<int>(height * cut_ratio); + std::uniform_int_distribution<int> width_uniform_distribution(0, width); + std::uniform_int_distribution<int> height_uniform_distribution(0, height); + int cx = width_uniform_distribution(rnd_); + int x2, y2; + int cy = height_uniform_distribution(rnd_); + *x = std::clamp(cx - cut_w / 2, 0, width - 1); // horizontal coordinate of left side of crop box + *y = std::clamp(cy - cut_h / 2, 0, height - 1); // vertical coordinate of the top side of crop box + x2 = std::clamp(cx + cut_w / 2, 0, width - 1); // horizontal coordinate of right side of crop box + y2 = std::clamp(cy + cut_h / 2, 0, height - 1); // vertical coordinate of the bottom side of crop box + *crop_width = std::clamp(x2 - *x, 1, width - 1); + *crop_height = std::clamp(y2 - *y, 1, height - 1); +} + +Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) { + if (input.size() < 2) { + RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation"); + } + + std::vector<std::shared_ptr<Tensor>> images; + std::vector<dsize_t> image_shape = input.at(0)->shape().AsVector(); + std::vector<dsize_t> label_shape = input.at(1)->shape().AsVector(); + + // Check inputs + if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) { + RETURN_STATUS_UNEXPECTED("You must batch before calling CutMixBatch."); + } + if (label_shape.size() != 2) { + RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch"); + } + if ((image_shape[1] != 1 && image_shape[1] != 3) && image_batch_format_ == ImageBatchFormat::kNCHW) { + RETURN_STATUS_UNEXPECTED("CutMixBatch: Image doesn't match the given image format."); + } + if ((image_shape[3] != 1 && image_shape[3] != 3) && image_batch_format_ == ImageBatchFormat::kNHWC) { + RETURN_STATUS_UNEXPECTED("CutMixBatch: Image doesn't match the given image format."); + } + + // Move images into a vector of Tensors + RETURN_IF_NOT_OK(BatchTensorToTensorVector(input.at(0), &images)); + + // Calculate random labels + std::vector<int64_t> rand_indx; + for (int64_t i = 0; i < images.size(); i++) rand_indx.push_back(i); + std::shuffle(rand_indx.begin(), rand_indx.end(), rnd_); + + std::gamma_distribution<float> gamma_distribution(alpha_, 1); + std::uniform_real_distribution<double> uniform_distribution(0.0, 1.0); + + // Tensor holding the output labels + std::shared_ptr<Tensor> out_labels; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(label_shape), DataType(DataType::DE_FLOAT32), &out_labels)); + + // Compute labels and images + for (int i = 0; i < image_shape[0]; i++) { + // Calculating lambda + // If x1 is a random variable from Gamma(a1, 1) and x2 is a random variable from Gamma(a2, 1) + // then x = x1 / (x1+x2) is a random variable from Beta(a1, a2) + float x1 = gamma_distribution(rnd_); + float x2 =
gamma_distribution(rnd_); + float lam = x1 / (x1 + x2); + double random_number = uniform_distribution(rnd_); + if (random_number < prob_) { + int x, y, crop_width, crop_height; + float label_lam; // lambda used for labels + + // Get a random image + TensorShape remaining({-1}); + uchar *start_addr_of_index = nullptr; + std::shared_ptr<Tensor> rand_image; + RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx[i], 0, 0, 0}, &start_addr_of_index, &remaining)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}), + input.at(0)->type(), start_addr_of_index, &rand_image)); + + // Compute image + if (image_batch_format_ == ImageBatchFormat::kNHWC) { + // NHWC Format + GetCropBox(static_cast<int>(image_shape[1]), static_cast<int>(image_shape[2]), lam, &x, &y, &crop_width, + &crop_height); + std::shared_ptr<Tensor> cropped; + RETURN_IF_NOT_OK(Crop(rand_image, &cropped, x, y, crop_width, crop_height)); + RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::HWC)); + label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[1] * image_shape[2])); + + } else { + // NCHW Format + GetCropBox(static_cast<int>(image_shape[2]), static_cast<int>(image_shape[3]), lam, &x, &y, &crop_width, + &crop_height); + std::vector<std::shared_ptr<Tensor>> channels; // A vector holding channels of the CHW image + std::vector<std::shared_ptr<Tensor>> cropped_channels; // A vector holding the channels of the cropped CHW + RETURN_IF_NOT_OK(BatchTensorToTensorVector(rand_image, &channels)); + for (auto channel : channels) { + // Call crop for each single channel + std::shared_ptr<Tensor> cropped_channel; + RETURN_IF_NOT_OK(Crop(channel, &cropped_channel, x, y, crop_width, crop_height)); + cropped_channels.push_back(cropped_channel); + } + std::shared_ptr<Tensor> cropped; + // Merge channels to a single tensor + RETURN_IF_NOT_OK(TensorVectorToBatchTensor(cropped_channels, &cropped)); + + RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::CHW)); + label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[2] * image_shape[3])); + } + + // Compute labels + for (int j = 0; j < label_shape[1]; j++) { + uint64_t first_value, second_value; + RETURN_IF_NOT_OK(input.at(1)->GetItemAt<uint64_t>(&first_value, {i, j})); + RETURN_IF_NOT_OK(input.at(1)->GetItemAt<uint64_t>(&second_value, {rand_indx[i] % label_shape[0], j})); + RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value)); + } + } + } + + std::shared_ptr<Tensor> out_images; + RETURN_IF_NOT_OK(TensorVectorToBatchTensor(images, &out_images)); + + // Move the output into a TensorRow + output->push_back(out_images); + output->push_back(out_labels); + + return Status::OK(); +} + +void CutMixBatchOp::Print(std::ostream &out) const { + out << "CutMixBatchOp: " + << "image_batch_format: " << image_batch_format_ << ", alpha: " << alpha_ << ", probability: " << prob_ << "\n"; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.h new file mode 100644 index 0000000000..1fb7d6de38 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_ + +#include <memory> +#include <random> +#include <string> +#include <vector> + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class CutMixBatchOp : public TensorOp { + public: + explicit CutMixBatchOp(ImageBatchFormat image_batch_format, float alpha, float prob); + + ~CutMixBatchOp() override = default; + + void Print(std::ostream &out) const override; + + void GetCropBox(int height, int width, float lam, int *x, int *y, int *crop_width, int *crop_height); + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kCutMixBatchOp; } + + private: + float alpha_; + float prob_; + ImageBatchFormat image_batch_format_; + std::mt19937 rnd_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index 5beaf81fc9..d43f33bc19 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -402,6 +402,62 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) { } } +Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Tensor> *input, int x, int y, + int crop_width, int crop_height, ImageFormat image_format) { + if (image_format == ImageFormat::HWC) { + if ((*input)->Rank() != 3 || ((*input)->shape()[2] != 1 && (*input)->shape()[2] != 3)) { + RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image shape doesn't match the given image_format."); + } + if (sub_mat->Rank() != 3 || (sub_mat->shape()[2] != 1 && sub_mat->shape()[2] != 3)) { + RETURN_STATUS_UNEXPECTED("MaskWithTensor: sub_mat shape doesn't match the given image_format."); + } + int number_of_channels = (*input)->shape()[2]; + for (int i = 0; i < crop_width; i++) { + for (int j = 0; j < crop_height; j++) { + for (int c = 0; c < number_of_channels; c++) { + uint8_t pixel_value; + RETURN_IF_NOT_OK(sub_mat->GetItemAt<uint8_t>(&pixel_value, {j, i, c})); + RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i, c}, pixel_value)); + } + } + } + } else if (image_format == ImageFormat::CHW) { + if ((*input)->Rank() != 3 || ((*input)->shape()[0] != 1 && (*input)->shape()[0] != 3)) { + RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image shape doesn't match the given image_format."); + } + if (sub_mat->Rank() != 3 || (sub_mat->shape()[0] != 1 && sub_mat->shape()[0] != 3)) { + RETURN_STATUS_UNEXPECTED("MaskWithTensor: sub_mat shape doesn't match the given image_format."); + } + int number_of_channels = (*input)->shape()[0]; + for (int i = 0; i < crop_width; i++) { + for (int j = 0; j < crop_height; j++) { + for (int c = 0; c < number_of_channels; c++) { + uint8_t pixel_value; + RETURN_IF_NOT_OK(sub_mat->GetItemAt<uint8_t>(&pixel_value, {c, j, i})); +
RETURN_IF_NOT_OK((*input)->SetItemAt({c, y + j, x + i}, pixel_value)); + } + } + } + } else if (image_format == ImageFormat::HW) { + if ((*input)->Rank() != 2) { + RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image shape doesn't match the given image_format."); + } + if (sub_mat->Rank() != 2) { + RETURN_STATUS_UNEXPECTED("MaskWithTensor: sub_mat shape doesn't match the given image_format."); + } + for (int i = 0; i < crop_width; i++) { + for (int j = 0; j < crop_height; j++) { + uint8_t pixel_value; + RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {j, i})); + RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i}, pixel_value)); + } + } + } else { + RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image format must be CHW, HWC, or HW."); + } + return Status::OK(); +} + Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *output) { try { std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h index b736f3854e..0dfd94e14c 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -120,6 +120,19 @@ Status Crop(const std::shared_ptr &input, std::shared_ptr *outpu /// \param output: Tensor of shape or and same input type. Status HwcToChw(std::shared_ptr input, std::shared_ptr *output); +/// \brief Masks the given part of the input image with a another image (sub_mat) +/// \param[in] sub_mat The image we want to mask with +/// \param[in] input The pointer to the image we want to mask +/// \param[in] x The horizontal coordinate of left side of crop box +/// \param[in] y The vertical coordinate of the top side of crop box +/// \param[in] width The width of the mask box +/// \param[in] height The height of the mask box +/// \param[in] image_format The format of the image (CHW or HWC) +/// \param[out] input Masks the input image in-place and returns it +/// @return Status ok/error +Status MaskWithTensor(const std::shared_ptr &sub_mat, std::shared_ptr *input, int x, int y, int width, + int height, ImageFormat image_format); + /// \brief Swap the red and blue pixels (RGB <-> BGR) /// \param input: Tensor of shape and any OpenCv compatible type, see CVTensor. 
/// \param output: Swapped image of same shape and type diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc index 6386ba121f..c3195a43f4 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc @@ -37,10 +37,12 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) { std::vector label_shape = input.at(1)->shape().AsVector(); // Check inputs - if (label_shape.size() != 2 || image_shape.size() != 4 || image_shape[0] != label_shape[0]) { + if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) { RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch"); } - + if (label_shape.size() != 2) { + RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch"); + } if ((image_shape[1] != 1 && image_shape[1] != 3) && (image_shape[3] != 1 && image_shape[3] != 3)) { RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW"); } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 9d7068f7cc..a365c7aba4 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -94,6 +94,7 @@ constexpr char kAutoContrastOp[] = "AutoContrastOp"; constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; constexpr char kDecodeOp[] = "DecodeOp"; constexpr char kCenterCropOp[] = "CenterCropOp"; +constexpr char kCutMixBatchOp[] = "CutMixBatchOp"; constexpr char kCutOutOp[] = "CutOutOp"; constexpr char kCropOp[] = "CropOp"; constexpr char kEqualizeOp[] = "EqualizeOp"; diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index d7c0033c84..fc7ad571c5 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -43,13 +43,14 @@ Examples: import numbers import mindspore._c_dataengine as cde -from .utils import Inter, Border +from .utils import Inter, Border, ImageBatchFormat from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ check_range, check_resize, check_rescale, check_pad, check_cutout, \ check_uniform_augment_cpp, \ check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ - check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER + check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \ + check_cut_mix_batch_c DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, @@ -60,6 +61,8 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT, Border.REFLECT: cde.BorderType.DE_BORDER_REFLECT, Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC} +DE_C_IMAGE_BATCH_FORMAT = {ImageBatchFormat.NHWC: cde.ImageBatchFormat.DE_IMAGE_BATCH_FORMAT_NHWC, + ImageBatchFormat.NCHW: cde.ImageBatchFormat.DE_IMAGE_BATCH_FORMAT_NCHW} def parse_padding(padding): if isinstance(padding, numbers.Number): @@ -143,6 +146,33 @@ class Decode(cde.DecodeOp): super().__init__(self.rgb) +class 
CutMixBatch(cde.CutMixBatchOp): + """ + Apply CutMix transformation on input batch of images and labels. + Note that you need to make labels into one-hot format and batch before calling this function. + + Args: + image_batch_format (ImageBatchFormat): The format of the input batch of images. Can be any of + [ImageBatchFormat.NHWC, ImageBatchFormat.NCHW] + alpha (float): Hyperparameter of the beta distribution (default = 1.0). + prob (float): The probability by which CutMix is applied to each image (default = 1.0). + + Examples: + >>> one_hot_op = data.OneHot(num_classes=10) + >>> data = data.map(input_columns=["label"], operations=one_hot_op) + >>> cutmix_batch_op = vision.CutMixBatch(ImageBatchFormat.NHWC, 1.0, 0.5) + >>> data = data.batch(5) + >>> data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op) + """ + + @check_cut_mix_batch_c + def __init__(self, image_batch_format, alpha=1.0, prob=1.0): + self.image_batch_format = image_batch_format.value + self.alpha = alpha + self.prob = prob + super().__init__(DE_C_IMAGE_BATCH_FORMAT[image_batch_format], alpha, prob) + + class CutOut(cde.CutOutOp): """ Randomly cut (mask) out a given number of square patches from the input Numpy image array. diff --git a/mindspore/dataset/transforms/vision/utils.py b/mindspore/dataset/transforms/vision/utils.py index 1abc66a467..223db6352c 100644 --- a/mindspore/dataset/transforms/vision/utils.py +++ b/mindspore/dataset/transforms/vision/utils.py @@ -30,3 +30,9 @@ class Border(str, Enum): EDGE: str = "edge" REFLECT: str = "reflect" SYMMETRIC: str = "symmetric" + + +# Image Batch Format +class ImageBatchFormat(IntEnum): + NHWC = 0 + NCHW = 1 diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index a4badd0b47..cc95c93162 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -19,7 +19,7 @@ from functools import wraps import numpy as np from mindspore._c_dataengine import TensorOp -from .utils import Inter, Border +from .utils import Inter, Border, ImageBatchFormat from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \ check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \ check_tensor_op, UINT8_MAX @@ -37,6 +37,20 @@ def check_crop_size(size): raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.") +def check_cut_mix_batch_c(method): + """Wrapper method to check the parameters of CutMixBatch.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs) + type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format") + check_pos_float32(alpha) + check_value(prob, [0, 1], "prob") + return method(self, *args, **kwargs) + + return new_method + + def check_resize_size(size): """Wrapper method to check the parameters of resize.""" if isinstance(size, int): diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index a6b16370db..4a62e5208f 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -20,6 +20,7 @@ SET(DE_UT_SRCS circular_pool_test.cc client_config_test.cc connector_test.cc + cutmix_batch_op_test.cc cut_out_op_test.cc datatype_test.cc decode_op_test.cc diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index 50183641db..57a562f94e
100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -25,6 +25,177 @@ class MindDataTestPipeline : public UT::DatasetOpTesting { protected: }; +TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) { + // Testing CutMixBatch on a batch of CHW images + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + int number_of_classes = 10; + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr hwc_to_chw = vision::HWC2CHW(); + EXPECT_NE(hwc_to_chw, nullptr); + + // Create a Map operation on ds + ds = ds->Map({hwc_to_chw},{"image"}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 5; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr one_hot_op = vision::OneHot(number_of_classes); + EXPECT_NE(one_hot_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({one_hot_op},{"label"}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0); + EXPECT_NE(cutmix_batch_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({cutmix_batch_op}, {"image", "label"}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + auto label = row["label"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + MS_LOG(INFO) << "Label shape: " << label->shape(); + EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1] + && 32 == image->shape()[2] && 32 == image->shape()[3], true); + EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] && + number_of_classes == label->shape()[1], true); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) { + // Calling CutMixBatch on a batch of HWC images with default values of alpha and prob + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + int number_of_classes = 10; + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 5; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr one_hot_op = vision::OneHot(number_of_classes); + EXPECT_NE(one_hot_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({one_hot_op},{"label"}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC); + EXPECT_NE(cutmix_batch_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({cutmix_batch_op}, {"image", "label"}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. 
+ std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + auto label = row["label"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + MS_LOG(INFO) << "Label shape: " << label->shape(); + EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1] + && 32 == image->shape()[2] && 3 == image->shape()[3], true); + EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] && + number_of_classes == label->shape()[1], true); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) { + // Must fail because alpha can't be negative + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 5; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr one_hot_op = vision::OneHot(10); + EXPECT_NE(one_hot_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({one_hot_op},{"label"}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, -1, 0.5); + EXPECT_EQ(cutmix_batch_op, nullptr); +} + +TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) { + // Must fail because prob can't be negative + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 5; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr one_hot_op = vision::OneHot(10); + EXPECT_NE(one_hot_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({one_hot_op},{"label"}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5); + EXPECT_EQ(cutmix_batch_op, nullptr); + +} + TEST_F(MindDataTestPipeline, TestCutOut) { // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; diff --git a/tests/ut/cpp/dataset/cutmix_batch_op_test.cc b/tests/ut/cpp/dataset/cutmix_batch_op_test.cc new file mode 100644 index 0000000000..1e927ed788 --- /dev/null +++ b/tests/ut/cpp/dataset/cutmix_batch_op_test.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "common/common.h" +#include "common/cvop_common.h" +#include "minddata/dataset/kernels/image/cutmix_batch_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +class MindDataTestCutMixBatchOp : public UT::CVOP::CVOpCommon { + protected: + MindDataTestCutMixBatchOp() : CVOpCommon() {} +}; + +TEST_F(MindDataTestCutMixBatchOp, TestSuccess1) { + MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp success1 case"; + std::shared_ptr batched_tensor; + std::shared_ptr batched_labels; + Tensor::CreateEmpty(TensorShape({2, input_tensor_->shape()[0], input_tensor_->shape()[1], input_tensor_->shape()[2]}), + input_tensor_->type(), &batched_tensor); + for (int i = 0; i < 2; i++) { + batched_tensor->InsertTensor({i}, input_tensor_); + } + Tensor::CreateFromVector(std::vector({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels); + std::shared_ptr op = std::make_shared(ImageBatchFormat::kNHWC, 1.0, 1.0); + TensorRow in; + in.push_back(batched_tensor); + in.push_back(batched_labels); + TensorRow out; + ASSERT_TRUE(op->Compute(in, &out).IsOk()); + + EXPECT_EQ(in.at(0)->shape()[0], out.at(0)->shape()[0]); + EXPECT_EQ(in.at(0)->shape()[1], out.at(0)->shape()[1]); + EXPECT_EQ(in.at(0)->shape()[2], out.at(0)->shape()[2]); + EXPECT_EQ(in.at(0)->shape()[3], out.at(0)->shape()[3]); + + EXPECT_EQ(in.at(1)->shape()[0], out.at(1)->shape()[0]); + EXPECT_EQ(in.at(1)->shape()[1], out.at(1)->shape()[1]); +} + +TEST_F(MindDataTestCutMixBatchOp, TestSuccess2) { + MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp success2 case"; + std::shared_ptr batched_tensor; + std::shared_ptr batched_labels; + std::shared_ptr chw_tensor; + ASSERT_TRUE(HwcToChw(input_tensor_, &chw_tensor).IsOk()); + Tensor::CreateEmpty(TensorShape({2, chw_tensor->shape()[0], chw_tensor->shape()[1], chw_tensor->shape()[2]}), + chw_tensor->type(), &batched_tensor); + for (int i = 0; i < 2; i++) { + batched_tensor->InsertTensor({i}, chw_tensor); + } + Tensor::CreateFromVector(std::vector({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels); + std::shared_ptr op = std::make_shared(ImageBatchFormat::kNCHW, 1.0, 0.5); + TensorRow in; + in.push_back(batched_tensor); + in.push_back(batched_labels); + TensorRow out; + ASSERT_TRUE(op->Compute(in, &out).IsOk()); + + EXPECT_EQ(in.at(0)->shape()[0], out.at(0)->shape()[0]); + EXPECT_EQ(in.at(0)->shape()[1], out.at(0)->shape()[1]); + EXPECT_EQ(in.at(0)->shape()[2], out.at(0)->shape()[2]); + EXPECT_EQ(in.at(0)->shape()[3], out.at(0)->shape()[3]); + + EXPECT_EQ(in.at(1)->shape()[0], out.at(1)->shape()[0]); + EXPECT_EQ(in.at(1)->shape()[1], out.at(1)->shape()[1]); +} + +TEST_F(MindDataTestCutMixBatchOp, TestFail1) { + // This is a fail case because our labels are not batched and are 1-dimensional + MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp fail1 case"; + std::shared_ptr labels; + Tensor::CreateFromVector(std::vector({0, 1, 1, 0}), TensorShape({4}), &labels); + std::shared_ptr op = std::make_shared(ImageBatchFormat::kNHWC, 1.0, 1.0); + TensorRow in; + in.push_back(input_tensor_); + in.push_back(labels); + TensorRow out; + ASSERT_FALSE(op->Compute(in, &out).IsOk()); +} + +TEST_F(MindDataTestCutMixBatchOp, TestFail2) { + // This should fail because the image_batch_format provided is not the same as the actual format of the images + MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp fail2 case"; + std::shared_ptr batched_tensor; + std::shared_ptr batched_labels; + 
Tensor::CreateEmpty(TensorShape({2, input_tensor_->shape()[0], input_tensor_->shape()[1], input_tensor_->shape()[2]}), + input_tensor_->type(), &batched_tensor); + for (int i = 0; i < 2; i++) { + batched_tensor->InsertTensor({i}, input_tensor_); + } + Tensor::CreateFromVector(std::vector({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels); + std::shared_ptr op = std::make_shared(ImageBatchFormat::kNCHW, 1.0, 1.0); + TensorRow in; + in.push_back(batched_tensor); + in.push_back(batched_labels); + TensorRow out; + ASSERT_FALSE(op->Compute(in, &out).IsOk()); +} diff --git a/tests/ut/data/dataset/golden/cutmix_batch_c_nchw_result.npz b/tests/ut/data/dataset/golden/cutmix_batch_c_nchw_result.npz new file mode 100644 index 0000000000..65c2e9a156 Binary files /dev/null and b/tests/ut/data/dataset/golden/cutmix_batch_c_nchw_result.npz differ diff --git a/tests/ut/data/dataset/golden/cutmix_batch_c_nhwc_result.npz b/tests/ut/data/dataset/golden/cutmix_batch_c_nhwc_result.npz new file mode 100644 index 0000000000..89c8f8733f Binary files /dev/null and b/tests/ut/data/dataset/golden/cutmix_batch_c_nhwc_result.npz differ diff --git a/tests/ut/python/dataset/test_cutmix_batch_op.py b/tests/ut/python/dataset/test_cutmix_batch_op.py new file mode 100644 index 0000000000..d283c27d13 --- /dev/null +++ b/tests/ut/python/dataset/test_cutmix_batch_op.py @@ -0,0 +1,336 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +""" +Testing the CutMixBatch op in DE +""" +import numpy as np +import pytest +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as vision +import mindspore.dataset.transforms.c_transforms as data_trans +import mindspore.dataset.transforms.vision.utils as mode +from mindspore import log as logger +from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \ + config_get_set_num_parallel_workers + +DATA_DIR = "../data/dataset/testCifar10Data" + +GENERATE_GOLDEN = False + + +def test_cutmix_batch_success1(plot=False): + """ + Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images + """ + logger.info("test_cutmix_batch_success1") + + # Original Images + ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + ds_original = ds_original.batch(5, drop_remainder=True) + + images_original = None + for idx, (image, _) in enumerate(ds_original): + if idx == 0: + images_original = image + else: + images_original = np.append(images_original, image, axis=0) + + # CutMix Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + hwc2chw_op = vision.HWC2CHW() + data1 = data1.map(input_columns=["image"], operations=hwc2chw_op) + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5) + data1 = data1.batch(5, drop_remainder=True) + data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) + + images_cutmix = None + for idx, (image, _) in enumerate(data1): + if idx == 0: + images_cutmix = image.transpose(0, 2, 3, 1) + else: + images_cutmix = np.append(images_cutmix, image.transpose(0, 2, 3, 1), axis=0) + if plot: + visualize_list(images_original, images_cutmix) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_cutmix[i], images_original[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + +def test_cutmix_batch_success2(plot=False): + """ + Test CutMixBatch op with default values for alpha and prob on a batch of HWC images + """ + logger.info("test_cutmix_batch_success2") + + # Original Images + ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + ds_original = ds_original.batch(5, drop_remainder=True) + + images_original = None + for idx, (image, _) in enumerate(ds_original): + if idx == 0: + images_original = image + else: + images_original = np.append(images_original, image, axis=0) + + # CutMix Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) + data1 = data1.batch(5, drop_remainder=True) + data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) + + images_cutmix = None + for idx, (image, _) in enumerate(data1): + if idx == 0: + images_cutmix = image + else: + images_cutmix = np.append(images_cutmix, image, axis=0) + if plot: + visualize_list(images_original, images_cutmix) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_cutmix[i], images_original[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + +def 
test_cutmix_batch_nhwc_md5(): + """ + Test CutMixBatch on a batch of HWC images with MD5: + """ + logger.info("test_cutmix_batch_nhwc_md5") + original_seed = config_get_set_seed(0) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + + # CutMixBatch Images + data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + one_hot_op = data_trans.OneHot(num_classes=10) + data = data.map(input_columns=["label"], operations=one_hot_op) + cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) + data = data.batch(5, drop_remainder=True) + data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op) + + filename = "cutmix_batch_c_nhwc_result.npz" + save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) + + # Restore config setting + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_cutmix_batch_nchw_md5(): + """ + Test CutMixBatch on a batch of CHW images with MD5: + """ + logger.info("test_cutmix_batch_nchw_md5") + original_seed = config_get_set_seed(0) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + + # CutMixBatch Images + data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + hwc2chw_op = vision.HWC2CHW() + data = data.map(input_columns=["image"], operations=hwc2chw_op) + one_hot_op = data_trans.OneHot(num_classes=10) + data = data.map(input_columns=["label"], operations=one_hot_op) + cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW) + data = data.batch(5, drop_remainder=True) + data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op) + + filename = "cutmix_batch_c_nchw_result.npz" + save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) + + # Restore config setting + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_cutmix_batch_fail1(): + """ + Test CutMixBatch Fail 1 + We expect this to fail because the images and labels are not batched + """ + logger.info("test_cutmix_batch_fail1") + + # CutMixBatch Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) + with pytest.raises(RuntimeError) as error: + data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) + for idx, (image, _) in enumerate(data1): + if idx == 0: + images_cutmix = image + else: + images_cutmix = np.append(images_cutmix, image, axis=0) + error_message = "You must batch before calling CutMixBatch" + assert error_message in str(error.value) + + +def test_cutmix_batch_fail2(): + """ + Test CutMixBatch Fail 2 + We expect this to fail because alpha is negative + """ + logger.info("test_cutmix_batch_fail2") + + # CutMixBatch Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + with pytest.raises(ValueError) as error: + vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1) + error_message = "Input is not within the required interval" + assert error_message in str(error.value) + + +def test_cutmix_batch_fail3(): + """ + Test CutMixBatch Fail 2 + We expect this to fail because prob is larger than 1 + """ + logger.info("test_cutmix_batch_fail3") + + # CutMixBatch Images + data1 = 
ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + with pytest.raises(ValueError) as error: + vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2) + error_message = "Input is not within the required interval" + assert error_message in str(error.value) + + +def test_cutmix_batch_fail4(): + """ + Test CutMixBatch Fail 2 + We expect this to fail because prob is negative + """ + logger.info("test_cutmix_batch_fail4") + + # CutMixBatch Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + with pytest.raises(ValueError) as error: + vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1) + error_message = "Input is not within the required interval" + assert error_message in str(error.value) + + +def test_cutmix_batch_fail5(): + """ + Test CutMixBatch op + We expect this to fail because label column is not passed to cutmix_batch + """ + logger.info("test_cutmix_batch_fail5") + + # CutMixBatch Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) + data1 = data1.batch(5, drop_remainder=True) + data1 = data1.map(input_columns=["image"], operations=cutmix_batch_op) + + with pytest.raises(RuntimeError) as error: + images_cutmix = np.array([]) + for idx, (image, _) in enumerate(data1): + if idx == 0: + images_cutmix = image + else: + images_cutmix = np.append(images_cutmix, image, axis=0) + error_message = "Both images and labels columns are required" + assert error_message in str(error.value) + + +def test_cutmix_batch_fail6(): + """ + Test CutMixBatch op + We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images + """ + logger.info("test_cutmix_batch_fail6") + + # CutMixBatch Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW) + data1 = data1.batch(5, drop_remainder=True) + data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) + + with pytest.raises(RuntimeError) as error: + images_cutmix = np.array([]) + for idx, (image, _) in enumerate(data1): + if idx == 0: + images_cutmix = image + else: + images_cutmix = np.append(images_cutmix, image, axis=0) + error_message = "CutMixBatch: Image doesn't match the given image format." 
+ assert error_message in str(error.value) + + +def test_cutmix_batch_fail7(): + """ + Test CutMixBatch op + We expect this to fail because labels are not in one-hot format + """ + logger.info("test_cutmix_batch_fail7") + + # CutMixBatch Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) + data1 = data1.batch(5, drop_remainder=True) + data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) + + with pytest.raises(RuntimeError) as error: + images_cutmix = np.array([]) + for idx, (image, _) in enumerate(data1): + if idx == 0: + images_cutmix = image + else: + images_cutmix = np.append(images_cutmix, image, axis=0) + error_message = "CutMixBatch: Label's must be in one-hot format and in a batch" + assert error_message in str(error.value) + + +if __name__ == "__main__": + test_cutmix_batch_success1(plot=True) + test_cutmix_batch_success2(plot=True) + test_cutmix_batch_nchw_md5() + test_cutmix_batch_nhwc_md5() + test_cutmix_batch_fail1() + test_cutmix_batch_fail2() + test_cutmix_batch_fail3() + test_cutmix_batch_fail4() + test_cutmix_batch_fail5() + test_cutmix_batch_fail6() + test_cutmix_batch_fail7()
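Usage sketch: a minimal Python example of the CutMixBatch op added in this change, mirroring the class docstring and the Python tests above. The dataset path, number of classes, and batch size are illustrative; the iteration shapes assume the CIFAR-10 test data used in the tests. Labels must be one-hot encoded and the data batched before the op runs; for each selected image the op pastes a random crop taken from another image in the batch and mixes the labels as label_lam * label_i + (1 - label_lam) * label_j, where label_lam is one minus the pasted area fraction.

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset.transforms.vision.utils as mode

# One-hot the labels, batch, then apply CutMixBatch on an NHWC batch (HWC per image).
data = ds.Cifar10Dataset("../data/dataset/testCifar10Data", num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10)
data = data.map(input_columns=["label"], operations=one_hot_op)
data = data.batch(5, drop_remainder=True)
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1.0, 0.5)
data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)

for image, label in data:
    # image: (5, 32, 32, 3) uint8 batch with pasted crops; label: (5, 10) float32 mixed one-hot vectors
    pass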