From d328d6eb54325685d4530747b9c2bfca396de2f0 Mon Sep 17 00:00:00 2001 From: Eric Date: Tue, 16 Feb 2021 23:48:22 -0500 Subject: [PATCH] Adding affine Need to move Affine to image_utilC Refactored affine Added affine API Added API class impl Removed default from vision ir Compiling removed extra file Spell check --- .../dataset/kernels/ir/image/bindings.cc | 4 +- .../ccsrc/minddata/dataset/api/vision.cc | 13 ++ .../ccsrc/minddata/dataset/include/vision.h | 43 ------- .../minddata/dataset/include/vision_lite.h | 77 ++++++++++++ .../dataset/kernels/image/affine_op.cc | 95 +++++++-------- .../dataset/kernels/image/affine_op.h | 7 +- .../dataset/kernels/image/image_utils.cc | 20 ++++ .../dataset/kernels/image/image_utils.h | 11 ++ .../dataset/kernels/image/lite_image_utils.cc | 41 +++++++ .../dataset/kernels/image/lite_image_utils.h | 83 +++++++------ .../dataset/kernels/image/math_utils.cc | 2 - .../dataset/kernels/image/math_utils.h | 2 + .../dataset/kernels/image/random_affine_op.cc | 4 + .../dataset/kernels/image/random_affine_op.h | 2 - .../dataset/kernels/ir/vision/vision_ir.cc | 58 +++++++++ .../dataset/kernels/ir/vision/vision_ir.h | 99 ++++++++------- .../minddata/dataset/kernels/tensor_op.h | 2 + mindspore/dataset/core/config.py | 2 +- mindspore/lite/minddata/CMakeLists.txt | 5 +- tests/ut/cpp/dataset/CMakeLists.txt | 2 + tests/ut/cpp/dataset/affine_op_test.cc | 113 ++++++++++++++++++ tests/ut/cpp/dataset/c_api_affine_test.cc | 97 +++++++++++++++ 22 files changed, 595 insertions(+), 187 deletions(-) create mode 100644 tests/ut/cpp/dataset/affine_op_test.cc create mode 100644 tests/ut/cpp/dataset/c_api_affine_test.cc diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc index fb5a8d054b..e33d1b1d8d 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc +++ 
b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc @@ -84,8 +84,8 @@ PYBIND_REGISTER(CutOutOperation, 1, ([](const py::module *m) { PYBIND_REGISTER(DecodeOperation, 1, ([](const py::module *m) { (void)py::class_>( *m, "DecodeOperation") - .def(py::init([]() { - auto decode = std::make_shared(); + .def(py::init([](bool rgb) { + auto decode = std::make_shared(rgb); THROW_IF_ERROR(decode->ValidateParams()); return decode; })) diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc index a3f40c2466..1e1863e223 100644 --- a/mindspore/ccsrc/minddata/dataset/api/vision.cc +++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc @@ -47,6 +47,19 @@ namespace vision { // FUNCTIONS TO CREATE VISION TRANSFORM OPERATIONS // (In alphabetical order) +Affine::Affine(float_t degrees, const std::vector &translation, float scale, const std::vector &shear, + InterpolationMode interpolation, const std::vector &fill_value) + : degrees_(degrees), + translation_(translation), + scale_(scale), + shear_(shear), + interpolation_(interpolation), + fill_value_(fill_value) {} + +std::shared_ptr Affine::Parse() { + return std::make_shared(degrees_, translation_, scale_, shear_, interpolation_, fill_value_); +} + // AutoContrast Transform Operation. AutoContrast::AutoContrast(float cutoff, std::vector ignore) : cutoff_(cutoff), ignore_(ignore) {} diff --git a/mindspore/ccsrc/minddata/dataset/include/vision.h b/mindspore/ccsrc/minddata/dataset/include/vision.h index 1453c69158..2a39b9fda0 100644 --- a/mindspore/ccsrc/minddata/dataset/include/vision.h +++ b/mindspore/ccsrc/minddata/dataset/include/vision.h @@ -35,7 +35,6 @@ class TensorOperation; // Transform operations for performing computer vision. namespace vision { - /// \brief AutoContrast TensorTransform. /// \notes Apply automatic contrast on input image. 
class AutoContrast : public TensorTransform { @@ -253,48 +252,6 @@ class Pad : public TensorTransform { BorderType padding_mode_; }; -/// \brief RandomAffine TensorTransform. -/// \notes Applies a Random Affine transformation on input image in RGB or Greyscale mode. -class RandomAffine : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] degrees A float vector of size 2, representing the starting and ending degree - /// \param[in] translate_range A float vector of size 2 or 4, representing percentages of translation on x and y axes. - /// if size is 2, (min_dx, max_dx, 0, 0) - /// if size is 4, (min_dx, max_dx, min_dy, max_dy) - /// all values are in range [-1, 1] - /// \param[in] scale_range A float vector of size 2, representing the starting and ending scales in the range. - /// \param[in] shear_ranges A float vector of size 2 or 4, representing the starting and ending shear degrees - /// vertically and horizontally. - /// if size is 2, (min_shear_x, max_shear_x, 0, 0) - /// if size is 4, (min_shear_x, max_shear_x, min_shear_y, max_shear_y) - /// \param[in] interpolation An enum for the mode of interpolation - /// \param[in] fill_value A vector representing the value to fill the area outside the transform - /// in the output image. If 1 value is provided, it is used for all RGB channels. - /// If 3 values are provided, it is used to fill R, G, B channels respectively. - explicit RandomAffine(const std::vector °rees, - const std::vector &translate_range = {0.0, 0.0, 0.0, 0.0}, - const std::vector &scale_range = {1.0, 1.0}, - const std::vector &shear_ranges = {0.0, 0.0, 0.0, 0.0}, - InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, - const std::vector &fill_value = {0, 0, 0}); - - /// \brief Destructor. - ~RandomAffine() = default; - - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. 
- std::shared_ptr Parse() override; - - private: - std::vector degrees_; // min_degree, max_degree - std::vector translate_range_; // maximum x translation percentage, maximum y translation percentage - std::vector scale_range_; // min_scale, max_scale - std::vector shear_ranges_; // min_x_shear, max_x_shear, min_y_shear, max_y_shear - InterpolationMode interpolation_; - std::vector fill_value_; -}; - /// \brief Blends an image with its grayscale version with random weights /// t and 1 - t generated from a given range. If the range is trivial /// then the weights are determinate and t equals the bound of the interval diff --git a/mindspore/ccsrc/minddata/dataset/include/vision_lite.h b/mindspore/ccsrc/minddata/dataset/include/vision_lite.h index e2cc616c43..de2c0ae4d2 100644 --- a/mindspore/ccsrc/minddata/dataset/include/vision_lite.h +++ b/mindspore/ccsrc/minddata/dataset/include/vision_lite.h @@ -35,6 +35,41 @@ namespace vision { // Forward Declarations class RotateOperation; +/// \brief Affine TensorTransform. +/// \notes Apply affine transform on input image. +class Affine : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] degrees The degrees to rotate the image by + /// \param[in] translation The value representing vertical and horizontal translation (default = {0.0, 0.0}) + /// The first value represent the x axis translation while the second represents y axis translation. + /// \param[in] scale The scaling factor for the image (default = 0.0) + /// \param[in] shear A float vector of size 2, representing the shear degrees (default = {0.0, 0.0}) + /// \param[in] interpolation An enum for the mode of interpolation + /// \param[in] fill_value A vector representing the value to fill the area outside the transform + /// in the output image. If 1 value is provided, it is used for all RGB channels. + /// If 3 values are provided, it is used to fill R, G, B channels respectively. 
+ explicit Affine(float_t degrees, const std::vector &translation = {0.0, 0.0}, float scale = 0.0, + const std::vector &shear = {0.0, 0.0}, + InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, + const std::vector &fill_value = {0, 0, 0}); + + /// \brief Destructor. + ~Affine() = default; + + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + float degrees_; + std::vector translation_; + float scale_; + std::vector shear_; + InterpolationMode interpolation_; + std::vector fill_value_; +}; + /// \brief CenterCrop TensorTransform. /// \notes Crops the input image at the center to the given size. class CenterCrop : public TensorTransform { @@ -121,6 +156,48 @@ class Normalize : public TensorTransform { std::vector std_; }; +/// \brief RandomAffine TensorTransform. +/// \notes Applies a Random Affine transformation on input image in RGB or Greyscale mode. +class RandomAffine : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] degrees A float vector of size 2, representing the starting and ending degree + /// \param[in] translate_range A float vector of size 2 or 4, representing percentages of translation on x and y axes. + /// if size is 2, (min_dx, max_dx, 0, 0) + /// if size is 4, (min_dx, max_dx, min_dy, max_dy) + /// all values are in range [-1, 1] + /// \param[in] scale_range A float vector of size 2, representing the starting and ending scales in the range. + /// \param[in] shear_ranges A float vector of size 2 or 4, representing the starting and ending shear degrees + /// vertically and horizontally. 
+ /// if size is 2, (min_shear_x, max_shear_x, 0, 0) + /// if size is 4, (min_shear_x, max_shear_x, min_shear_y, max_shear_y) + /// \param[in] interpolation An enum for the mode of interpolation + /// \param[in] fill_value A vector representing the value to fill the area outside the transform + /// in the output image. If 1 value is provided, it is used for all RGB channels. + /// If 3 values are provided, it is used to fill R, G, B channels respectively. + explicit RandomAffine(const std::vector °rees, + const std::vector &translate_range = {0.0, 0.0, 0.0, 0.0}, + const std::vector &scale_range = {1.0, 1.0}, + const std::vector &shear_ranges = {0.0, 0.0, 0.0, 0.0}, + InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, + const std::vector &fill_value = {0, 0, 0}); + + /// \brief Destructor. + ~RandomAffine() = default; + + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + std::vector degrees_; // min_degree, max_degree + std::vector translate_range_; // maximum x translation percentage, maximum y translation percentage + std::vector scale_range_; // min_scale, max_scale + std::vector shear_ranges_; // min_x_shear, max_x_shear, min_y_shear, max_y_shear + InterpolationMode interpolation_; + std::vector fill_value_; +}; + /// \brief Resize TensorTransform. /// \notes Resize the input image to the given size. 
class Resize : public TensorTransform { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/affine_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/affine_op.cc index 78bb6fa7c4..a34f9dec7c 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/affine_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/affine_op.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,11 @@ #include #include "minddata/dataset/kernels/image/affine_op.h" +#ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/image_utils.h" +#else +#include "minddata/dataset/kernels/image/lite_image_utils.h" +#endif #include "minddata/dataset/kernels/image/math_utils.h" #include "minddata/dataset/util/random.h" @@ -45,59 +49,46 @@ AffineOp::AffineOp(float_t degrees, const std::vector &translation, flo Status AffineOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); - try { - float_t translation_x = translation_[0]; - float_t translation_y = translation_[1]; - float_t degrees = 0.0; - DegreesToRadians(degrees_, °rees); - float_t shear_x = shear_[0]; - float_t shear_y = shear_[1]; - DegreesToRadians(shear_x, &shear_x); - DegreesToRadians(-1 * shear_y, &shear_y); - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + float_t translation_x = translation_[0]; + float_t translation_y = translation_[1]; + float_t degrees = 0.0; + DegreesToRadians(degrees_, °rees); + float_t shear_x = shear_[0]; + float_t shear_y = shear_[1]; + DegreesToRadians(shear_x, &shear_x); + DegreesToRadians(-1 * shear_y, &shear_y); - // Apply Affine Transformation - // T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] - // C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] - // RSS is rotation with scale and shear 
matrix - // RSS(a, s, (sx, sy)) = - // = R(a) * S(s) * SHy(sy) * SHx(sx) - // = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ] - // [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ] - // [ 0 , 0 , 1 ] - // - // where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: - // SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] - // [0, 1 ] [-tan(s), 1] - // - // Thus, the affine matrix is M = T * C * RSS * C^-1 + // Apply Affine Transformation + // T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] + // C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] + // RSS is rotation with scale and shear matrix + // RSS(a, s, (sx, sy)) = + // = R(a) * S(s) * SHy(sy) * SHx(sx) + // = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ] + // [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ] + // [ 0 , 0 , 1 ] + // + // where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: + // SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] + // [0, 1 ] [-tan(s), 1] + // + // Thus, the affine matrix is M = T * C * RSS * C^-1 - float_t cx = ((input_cv->mat().cols - 1) / 2.0); - float_t cy = ((input_cv->mat().rows - 1) / 2.0); - // Calculate RSS - std::vector matrix{ - static_cast(scale_ * cos(degrees + shear_y) / cos(shear_y)), - static_cast(scale_ * (-1 * cos(degrees + shear_y) * tan(shear_x) / cos(shear_y) - sin(degrees))), - 0, - static_cast(scale_ * sin(degrees + shear_y) / cos(shear_y)), - static_cast(scale_ * (-1 * sin(degrees + shear_y) * tan(shear_x) / cos(shear_y) + cos(degrees))), - 0}; - // Compute T * C * RSS * C^-1 - matrix[2] = (1 - matrix[0]) * cx - matrix[1] * cy + translation_x; - matrix[5] = (1 - matrix[4]) * cy - matrix[3] * cx + translation_y; - cv::Mat affine_mat(matrix); - affine_mat = affine_mat.reshape(1, {2, 3}); - - std::shared_ptr output_cv; - RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), 
&output_cv)); - RETURN_UNEXPECTED_IF_NULL(output_cv); - cv::warpAffine(input_cv->mat(), output_cv->mat(), affine_mat, input_cv->mat().size(), - GetCVInterpolationMode(interpolation_), cv::BORDER_CONSTANT, - cv::Scalar(fill_value_[0], fill_value_[1], fill_value_[2])); - (*output) = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Affine: " + std::string(e.what())); - } + // image is hwc, rows = shape()[0] + float_t cx = ((input->shape()[1] - 1) / 2.0); + float_t cy = ((input->shape()[0] - 1) / 2.0); + // Calculate RSS + std::vector matrix{ + static_cast(scale_ * cos(degrees + shear_y) / cos(shear_y)), + static_cast(scale_ * (-1 * cos(degrees + shear_y) * tan(shear_x) / cos(shear_y) - sin(degrees))), + 0, + static_cast(scale_ * sin(degrees + shear_y) / cos(shear_y)), + static_cast(scale_ * (-1 * sin(degrees + shear_y) * tan(shear_x) / cos(shear_y) + cos(degrees))), + 0}; + // Compute T * C * RSS * C^-1 + matrix[2] = (1 - matrix[0]) * cx - matrix[1] * cy + translation_x; + matrix[5] = (1 - matrix[4]) * cy - matrix[3] * cx + translation_y; + RETURN_IF_NOT_OK(Affine(input, output, matrix, interpolation_, fill_value_[0], fill_value_[1], fill_value_[2])); return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/affine_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/affine_op.h index 947125a336..cb178d6162 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/affine_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/affine_op.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,7 +21,6 @@ #include #include -#include "minddata/dataset/core/cv_tensor.h" #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/util/status.h" @@ -49,10 +48,6 @@ class AffineOp : public TensorOp { Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - /// Member variables - private: - std::string kAffineOp = "AffineOp"; - protected: float_t degrees_; std::vector translation_; // translation_x and translation_y diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index dc09a5d843..e86a8c0fd2 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -1101,5 +1101,25 @@ Status GetJpegImageInfo(const std::shared_ptr &input, int *img_width, in jpeg_destroy_decompress(&cinfo); return Status::OK(); } + +Status Affine(const std::shared_ptr &input, std::shared_ptr *output, const std::vector &mat, + InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + cv::Mat affine_mat(mat); + affine_mat = affine_mat.reshape(1, {2, 3}); + + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); + RETURN_UNEXPECTED_IF_NULL(output_cv); + cv::warpAffine(input_cv->mat(), output_cv->mat(), affine_mat, input_cv->mat().size(), + GetCVInterpolationMode(interpolation), cv::BORDER_CONSTANT, cv::Scalar(fill_r, fill_g, fill_b)); + (*output) = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Affine: " + std::string(e.what())); + } +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h 
b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h index dab2234587..518cdee0ab 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -299,6 +299,17 @@ Status RgbaToBgr(const std::shared_ptr &input, std::shared_ptr * /// \param img_height: the jpeg image height Status GetJpegImageInfo(const std::shared_ptr &input, int *img_width, int *img_height); +/// \brief Geometrically transform the input image +/// \param[in] input Input Tensor +/// \param[out] output Transformed Tensor +/// \param[in] mat The transformation matrix +/// \param[in] interpolation The interpolation mode +/// \param[in] fill_r Red fill value for pad +/// \param[in] fill_g Green fill value for pad +/// \param[in] fill_b Blue fill value for pad +Status Affine(const std::shared_ptr &input, std::shared_ptr *output, const std::vector &mat, + InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc index 782b60cd93..4e399e8413 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc @@ -621,5 +621,46 @@ Status Rotate(const std::shared_ptr &input, std::shared_ptr *out return RotateAngleWithMirror(input, output, orientation); } } + +Status Affine(const std::shared_ptr &input, std::shared_ptr *output, const std::vector &mat, + InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { + try { + if (interpolation != InterpolationMode::kLinear) { + MS_LOG(WARNING) << "Only Bilinear interpolation supported for now"; + } + int height = 0; + int width = 0; + double M[6] = {}; + for (int i = 
0; i < mat.size(); i++) { + M[i] = static_cast(mat[i]); + } + + LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2], + const_cast(reinterpret_cast(input->GetBuffer())), + GetLiteCVDataType(input->type())); + + height = lite_mat_rgb.height_; + width = lite_mat_rgb.width_; + std::vector dsize; + dsize.push_back(width); + dsize.push_back(height); + LiteMat lite_mat_affine; + std::shared_ptr output_tensor; + TensorShape new_shape = TensorShape({height, width, input->shape()[2]}); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input->type(), &output_tensor)); + uint8_t *buffer = reinterpret_cast(&(*output_tensor->begin())); + lite_mat_affine.Init(width, height, lite_mat_rgb.channel_, reinterpret_cast(buffer), + GetLiteCVDataType(input->type())); + + bool ret = Affine(lite_mat_rgb, lite_mat_affine, M, dsize, UINT8_C3(fill_r, fill_g, fill_b)); + CHECK_FAIL_RETURN_UNEXPECTED(ret, "Affine: affine failed."); + + *output = output_tensor; + return Status::OK(); + } catch (std::runtime_error &e) { + RETURN_STATUS_UNEXPECTED("Affine: " + std::string(e.what())); + } +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h index 72276cc7fb..c89fe321cf 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,70 +52,81 @@ Status JpegCropAndDecode(const std::shared_ptr &input, std::shared_ptr or and any OpenCv compatible type, see CVTensor. 
-/// \param x: starting horizontal position of ROI -/// \param y: starting vertical position of ROI -/// \param w: width of the ROI -/// \param h: height of the ROI -/// \param output: Cropped image Tensor of shape or and same input type. +/// \param[in] input Tensor of shape or and any OpenCv compatible type, see CVTensor. +/// \param[in] x Starting horizontal position of ROI +/// \param[in] y Starting vertical position of ROI +/// \param[in] w Width of the ROI +/// \param[in] h Height of the ROI +/// \param[out] output Cropped image Tensor of shape or and same input type. Status Crop(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, int w, int h); /// \brief Returns Decoded image /// Supported images: /// BMP JPEG JPG PNG TIFF /// supported by opencv, if user need more image analysis capabilities, please compile opencv particularlly. -/// \param input: CVTensor containing the not decoded image 1D bytes -/// \param output: Decoded image Tensor of shape and type DE_UINT8. Pixel order is RGB +/// \param[in] input CVTensor containing the not decoded image 1D bytes +/// \param[out] output Decoded image Tensor of shape and type DE_UINT8. Pixel order is RGB Status Decode(const std::shared_ptr &input, std::shared_ptr *output); /// \brief Get jpeg image width and height -/// \param input: CVTensor containing the not decoded image 1D bytes -/// \param img_width: the jpeg image width -/// \param img_height: the jpeg image height +/// \param[in] input CVTensor containing the not decoded image 1D bytes +/// \param[out] img_width The jpeg image width +/// \param[out] img_height The jpeg image height Status GetJpegImageInfo(const std::shared_ptr &input, int *img_width, int *img_height); /// \brief Returns Normalized image -/// \param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. 
-/// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order -/// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order -/// \param output: Normalized image Tensor of same input shape and type DE_FLOAT32 +/// \param[in] input Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. +/// \param[in] mean Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order +/// \param[in] std Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order +/// \param[out] output Normalized image Tensor of same input shape and type DE_FLOAT32 Status Normalize(const std::shared_ptr &input, std::shared_ptr *output, const std::shared_ptr &mean, const std::shared_ptr &std); /// \brief Returns Resized image. -/// \param input/output: Tensor of shape or and any OpenCv compatible type, see CVTensor. -/// \param output_height: height of output -/// \param output_width: width of output -/// \param fx: horizontal scale -/// \param fy: vertical scale -/// \param InterpolationMode: the interpolation mode -/// \param output: Resized image of shape or +/// \param[in] input Input image Tensor +/// \param[in] output_height Height of output +/// \param[in] output_width Width of output +/// \param[in] fx Horizontal scale +/// \param[in] fy Vertical scale +/// \param[in] mode The interpolation mode +/// \param[out] output Resized image of shape or /// and same type as input Status Resize(const std::shared_ptr &input, std::shared_ptr *output, int32_t output_height, int32_t output_width, double fx = 0.0, double fy = 0.0, InterpolationMode mode = InterpolationMode::kLinear); /// \brief Pads the input image and puts the padded image in the output -/// \param input: input Tensor -/// \param output: padded Tensor -/// \param pad_top: amount of padding done in top -/// \param pad_bottom: amount of padding done in bottom -/// \param pad_left: amount of padding done in 
left -/// \param pad_right: amount of padding done in right -/// \param border_types: the interpolation to be done in the border -/// \param fill_r: red fill value for pad -/// \param fill_g: green fill value for pad -/// \param fill_b: blue fill value for pad. +/// \param[in] input: input Tensor +/// \param[out] output: padded Tensor +/// \param[in] pad_top Amount of padding done in top +/// \param[in] pad_bottom Amount of padding done in bottom +/// \param[in] pad_left Amount of padding done in left +/// \param[in] pad_right Amount of padding done in right +/// \param[in] border_types The interpolation to be done in the border +/// \param[in] fill_r Red fill value for pad +/// \param[in] fill_g Green fill value for pad +/// \param[in] fill_b Blue fill value for pad Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); /// \brief Rotate the input image by orientation -/// \param input: input Tensor -/// \param output: padded Tensor -/// \param orientation: the orientation of EXIF +/// \param[in] input Input Tensor +/// \param[out] output Rotated Tensor +/// \param[in] orientation The orientation of EXIF Status Rotate(const std::shared_ptr &input, std::shared_ptr *output, const uint64_t orientation); +/// \brief Geometrically transform the input image +/// \param[in] input Input Tensor +/// \param[out] output Transformed Tensor +/// \param[in] mat The transformation matrix +/// \param[in] interpolation The interpolation mode, support only bilinear for now +/// \param[in] fill_r Red fill value for pad +/// \param[in] fill_g Green fill value for pad +/// \param[in] fill_b Blue fill value for pad +Status Affine(const std::shared_ptr &input, std::shared_ptr *output, const std::vector &mat, + InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, 
uint8_t fill_b = 0); + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/math_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/math_utils.cc index ce8205c86d..8f5808e60e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/math_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/math_utils.cc @@ -16,8 +16,6 @@ #include "minddata/dataset/kernels/image/math_utils.h" -#include - #include #include diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/math_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/math_utils.h index 66c28c0ae3..ac24ae87e8 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/math_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/math_utils.h @@ -21,6 +21,8 @@ #include #include "minddata/dataset/util/status.h" +#define CV_PI 3.1415926535897932384626433832795 + namespace mindspore { namespace dataset { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.cc index c57065e304..a236fe6571 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.cc @@ -19,7 +19,11 @@ #include #include "minddata/dataset/kernels/image/random_affine_op.h" +#ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/image_utils.h" +#else +#include "minddata/dataset/kernels/image/lite_image_utils.h" +#endif #include "minddata/dataset/kernels/image/math_utils.h" #include "minddata/dataset/util/random.h" diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.h index dcaad46817..1d2ca9703a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.h +++ 
b/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.h @@ -21,7 +21,6 @@ #include #include -#include "minddata/dataset/core/cv_tensor.h" #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/kernels/image/affine_op.h" #include "minddata/dataset/util/status.h" @@ -51,7 +50,6 @@ class RandomAffineOp : public AffineOp { Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; private: - std::string kRandomAffineOp = "RandomAffineOp"; std::vector degrees_range_; // min_degree, max_degree std::vector translate_range_; // maximum x translation percentage, maximum y translation percentage std::vector scale_range_; // min_scale, max_scale diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc index 4229416d86..ca54780ed2 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc @@ -21,6 +21,7 @@ #include "minddata/dataset/kernels/image/image_utils.h" #endif // Kernel image headers (in alphabetical order) +#include "minddata/dataset/kernels/image/affine_op.h" #ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/auto_contrast_op.h" #include "minddata/dataset/kernels/image/bounding_box_augment_op.h" @@ -42,7 +43,9 @@ #include "minddata/dataset/kernels/image/normalize_pad_op.h" #ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/pad_op.h" +#endif #include "minddata/dataset/kernels/image/random_affine_op.h" +#ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/random_color_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" @@ -88,6 +91,59 @@ namespace vision { /* ####################################### Derived TensorOperation classes ################################# */ // (In alphabetical order) + +// AffineOperation 
+AffineOperation::AffineOperation(float_t degrees, const std::vector &translation, float scale, + const std::vector &shear, InterpolationMode interpolation, + const std::vector &fill_value) + : degrees_(degrees), + translation_(translation), + scale_(scale), + shear_(shear), + interpolation_(interpolation), + fill_value_(fill_value) {} + +Status AffineOperation::ValidateParams() { + // Translate + if (translation_.size() != 2) { + std::string err_msg = + "Affine: translate expecting size 2, got: translation.size() = " + std::to_string(translation_.size()); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + RETURN_IF_NOT_OK(ValidateScalar("Affine", "translate", translation_[0], {-1, 1}, false, false)); + RETURN_IF_NOT_OK(ValidateScalar("Affine", "translate", translation_[1], {-1, 1}, false, false)); + + // Shear + if (shear_.size() != 2) { + std::string err_msg = "Affine: shear_ranges expecting size 2, got: shear.size() = " + std::to_string(shear_.size()); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + // Fill Value + RETURN_IF_NOT_OK(ValidateVectorFillvalue("Affine", fill_value_)); + + return Status::OK(); +} + +std::shared_ptr AffineOperation::Build() { + std::shared_ptr tensor_op = + std::make_shared(degrees_, translation_, scale_, shear_, interpolation_, fill_value_); + return tensor_op; +} + +Status AffineOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["degrees"] = degrees_; + args["translate"] = translation_; + args["scale"] = scale_; + args["shear"] = shear_; + args["resample"] = interpolation_; + args["fill_value"] = fill_value_; + *out_json = args; + return Status::OK(); +} + #ifndef ENABLE_ANDROID // AutoContrastOperation @@ -257,6 +313,7 @@ Status CutOutOperation::to_json(nlohmann::json *out_json) { *out_json = args; return Status::OK(); } +#endif // DecodeOperation DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {} @@ -269,6 +326,7 @@ Status 
DecodeOperation::to_json(nlohmann::json *out_json) { (*out_json)["rgb"] = rgb_; return Status::OK(); } +#ifndef ENABLE_ANDROID // EqualizeOperation Status EqualizeOperation::ValidateParams() { return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h index 0124359c54..953c0fecf7 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h @@ -35,6 +35,7 @@ namespace dataset { namespace vision { // Char arrays storing name of corresponding classes (in alphabetical order) +constexpr char kAffineOperation[] = "Affine"; constexpr char kAutoContrastOperation[] = "AutoContrast"; constexpr char kBoundingBoxAugmentOperation[] = "BoundingBoxAugment"; constexpr char kCenterCropOperation[] = "CenterCrop"; @@ -81,9 +82,34 @@ constexpr char kUniformAugOperation[] = "UniformAug"; /* ####################################### Derived TensorOperation classes ################################# */ +class AffineOperation : public TensorOperation { + public: + explicit AffineOperation(float_t degrees, const std::vector &translation, float scale, + const std::vector &shear, InterpolationMode interpolation, + const std::vector &fill_value); + + ~AffineOperation() = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kAffineOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + float degrees_; + std::vector translation_; + float scale_; + std::vector shear_; + InterpolationMode interpolation_; + std::vector fill_value_; +}; + class AutoContrastOperation : public TensorOperation { public: - explicit AutoContrastOperation(float cutoff = 0.0, std::vector ignore = {}); + explicit AutoContrastOperation(float cutoff, std::vector ignore); ~AutoContrastOperation() = default; @@ -102,7 +128,7 @@ class 
AutoContrastOperation : public TensorOperation { class BoundingBoxAugmentOperation : public TensorOperation { public: - explicit BoundingBoxAugmentOperation(std::shared_ptr transform, float ratio = 0.3); + explicit BoundingBoxAugmentOperation(std::shared_ptr transform, float ratio); ~BoundingBoxAugmentOperation() = default; @@ -156,7 +182,7 @@ class CropOperation : public TensorOperation { class CutMixBatchOperation : public TensorOperation { public: - explicit CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha = 1.0, float prob = 1.0); + explicit CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha, float prob); ~CutMixBatchOperation() = default; @@ -176,7 +202,7 @@ class CutMixBatchOperation : public TensorOperation { class CutOutOperation : public TensorOperation { public: - explicit CutOutOperation(int32_t length, int32_t num_patches = 1); + explicit CutOutOperation(int32_t length, int32_t num_patches); ~CutOutOperation() = default; @@ -195,7 +221,7 @@ class CutOutOperation : public TensorOperation { class DecodeOperation : public TensorOperation { public: - explicit DecodeOperation(bool rgb = true); + explicit DecodeOperation(bool rgb); ~DecodeOperation() = default; @@ -246,7 +272,7 @@ class InvertOperation : public TensorOperation { class MixUpBatchOperation : public TensorOperation { public: - explicit MixUpBatchOperation(float alpha = 1); + explicit MixUpBatchOperation(float alpha); ~MixUpBatchOperation() = default; @@ -283,8 +309,7 @@ class NormalizeOperation : public TensorOperation { class NormalizePadOperation : public TensorOperation { public: - NormalizePadOperation(const std::vector &mean, const std::vector &std, - const std::string &dtype = "float32"); + NormalizePadOperation(const std::vector &mean, const std::vector &std, const std::string &dtype); ~NormalizePadOperation() = default; @@ -304,8 +329,7 @@ class NormalizePadOperation : public TensorOperation { class PadOperation : public TensorOperation { public: - 
PadOperation(std::vector padding, std::vector fill_value = {0}, - BorderType padding_mode = BorderType::kConstant); + PadOperation(std::vector padding, std::vector fill_value, BorderType padding_mode); ~PadOperation() = default; @@ -325,11 +349,9 @@ class PadOperation : public TensorOperation { class RandomAffineOperation : public TensorOperation { public: - RandomAffineOperation(const std::vector &degrees, const std::vector &translate_range = {0.0, 0.0}, - const std::vector &scale_range = {1.0, 1.0}, - const std::vector &shear_ranges = {0.0, 0.0, 0.0, 0.0}, - InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, - const std::vector &fill_value = {0, 0, 0}); + RandomAffineOperation(const std::vector &degrees, const std::vector &translate_range, + const std::vector &scale_range, const std::vector &shear_ranges, + InterpolationMode interpolation, const std::vector &fill_value); ~RandomAffineOperation() = default; @@ -371,8 +393,8 @@ class RandomColorOperation : public TensorOperation { class RandomColorAdjustOperation : public TensorOperation { public: - RandomColorAdjustOperation(std::vector brightness = {1.0, 1.0}, std::vector contrast = {1.0, 1.0}, - std::vector saturation = {1.0, 1.0}, std::vector hue = {0.0, 0.0}); + RandomColorAdjustOperation(std::vector brightness, std::vector contrast, std::vector saturation, + std::vector hue); ~RandomColorAdjustOperation() = default; @@ -393,9 +415,8 @@ class RandomColorAdjustOperation : public TensorOperation { class RandomCropOperation : public TensorOperation { public: - RandomCropOperation(std::vector size, std::vector padding = {0, 0, 0, 0}, - bool pad_if_needed = false, std::vector fill_value = {0, 0, 0}, - BorderType padding_mode = BorderType::kConstant); + RandomCropOperation(std::vector size, std::vector padding, bool pad_if_needed, + std::vector fill_value, BorderType padding_mode); ~RandomCropOperation() = default; @@ -417,10 +438,8 @@ class RandomCropOperation : public TensorOperation { class
RandomResizedCropOperation : public TensorOperation { public: - RandomResizedCropOperation(std::vector size, std::vector scale = {0.08, 1.0}, - std::vector ratio = {3. / 4., 4. / 3.}, - InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, - int32_t max_attempts = 10); + RandomResizedCropOperation(std::vector size, std::vector scale, std::vector ratio, + InterpolationMode interpolation, int32_t max_attempts); /// \brief default copy constructor explicit RandomResizedCropOperation(const RandomResizedCropOperation &) = default; @@ -461,9 +480,8 @@ class RandomCropDecodeResizeOperation : public RandomResizedCropOperation { class RandomCropWithBBoxOperation : public TensorOperation { public: - RandomCropWithBBoxOperation(std::vector size, std::vector padding = {0, 0, 0, 0}, - bool pad_if_needed = false, std::vector fill_value = {0, 0, 0}, - BorderType padding_mode = BorderType::kConstant); + RandomCropWithBBoxOperation(std::vector size, std::vector padding, bool pad_if_needed, + std::vector fill_value, BorderType padding_mode); ~RandomCropWithBBoxOperation() = default; @@ -485,7 +503,7 @@ class RandomCropWithBBoxOperation : public TensorOperation { class RandomHorizontalFlipOperation : public TensorOperation { public: - explicit RandomHorizontalFlipOperation(float probability = 0.5); + explicit RandomHorizontalFlipOperation(float probability); ~RandomHorizontalFlipOperation() = default; @@ -503,7 +521,7 @@ class RandomHorizontalFlipOperation : public TensorOperation { class RandomHorizontalFlipWithBBoxOperation : public TensorOperation { public: - explicit RandomHorizontalFlipWithBBoxOperation(float probability = 0.5); + explicit RandomHorizontalFlipWithBBoxOperation(float probability); ~RandomHorizontalFlipWithBBoxOperation() = default; @@ -521,7 +539,7 @@ class RandomHorizontalFlipWithBBoxOperation : public TensorOperation { class RandomPosterizeOperation : public TensorOperation { public: - explicit RandomPosterizeOperation(const std::vector 
&bit_range = {4, 8}); + explicit RandomPosterizeOperation(const std::vector &bit_range); ~RandomPosterizeOperation() = default; @@ -575,10 +593,9 @@ class RandomResizeWithBBoxOperation : public TensorOperation { class RandomResizedCropWithBBoxOperation : public TensorOperation { public: - explicit RandomResizedCropWithBBoxOperation(std::vector size, std::vector scale = {0.08, 1.0}, - std::vector ratio = {3. / 4., 4. / 3.}, - InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, - int32_t max_attempts = 10); + explicit RandomResizedCropWithBBoxOperation(std::vector size, std::vector scale, + std::vector ratio, InterpolationMode interpolation, + int32_t max_attempts); ~RandomResizedCropWithBBoxOperation() = default; @@ -642,7 +659,7 @@ class RandomSelectSubpolicyOperation : public TensorOperation { class RandomSharpnessOperation : public TensorOperation { public: - explicit RandomSharpnessOperation(std::vector degrees = {0.1, 1.9}); + explicit RandomSharpnessOperation(std::vector degrees); ~RandomSharpnessOperation() = default; @@ -678,7 +695,7 @@ class RandomSolarizeOperation : public TensorOperation { class RandomVerticalFlipOperation : public TensorOperation { public: - explicit RandomVerticalFlipOperation(float probability = 0.5); + explicit RandomVerticalFlipOperation(float probability); ~RandomVerticalFlipOperation() = default; @@ -696,7 +713,7 @@ class RandomVerticalFlipOperation : public TensorOperation { class RandomVerticalFlipWithBBoxOperation : public TensorOperation { public: - explicit RandomVerticalFlipWithBBoxOperation(float probability = 0.5); + explicit RandomVerticalFlipWithBBoxOperation(float probability); ~RandomVerticalFlipWithBBoxOperation() = default; @@ -733,8 +750,7 @@ class RescaleOperation : public TensorOperation { class ResizeOperation : public TensorOperation { public: - explicit ResizeOperation(std::vector size, - InterpolationMode interpolation_mode = InterpolationMode::kLinear); + explicit 
ResizeOperation(std::vector size, InterpolationMode interpolation_mode); ~ResizeOperation() = default; @@ -753,8 +769,7 @@ class ResizeOperation : public TensorOperation { class ResizeWithBBoxOperation : public TensorOperation { public: - explicit ResizeWithBBoxOperation(std::vector size, - InterpolationMode interpolation_mode = InterpolationMode::kLinear); + explicit ResizeWithBBoxOperation(std::vector size, InterpolationMode interpolation_mode); ~ResizeWithBBoxOperation() = default; @@ -870,7 +885,7 @@ class SwapRedBlueOperation : public TensorOperation { class UniformAugOperation : public TensorOperation { public: - explicit UniformAugOperation(std::vector> transforms, int32_t num_ops = 2); + explicit UniformAugOperation(std::vector> transforms, int32_t num_ops); ~UniformAugOperation() = default; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index e2731d0319..c5ad338694 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -53,6 +53,7 @@ namespace dataset { constexpr char kTensorOp[] = "TensorOp"; // image +constexpr char kAffineOp[] = "AffineOp"; constexpr char kAutoContrastOp[] = "AutoContrastOp"; constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; constexpr char kDecodeOp[] = "DecodeOp"; @@ -73,6 +74,7 @@ constexpr char kMixUpBatchOp[] = "MixUpBatchOp"; constexpr char kNormalizeOp[] = "NormalizeOp"; constexpr char kNormalizePadOp[] = "NormalizePadOp"; constexpr char kPadOp[] = "PadOp"; +constexpr char kRandomAffineOp[] = "RandomAffineOp"; constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp"; constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp"; diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py index 0a1af6c9ad..4a41bd8c8f 100644 --- a/mindspore/dataset/core/config.py 
+++ b/mindspore/dataset/core/config.py @@ -229,7 +229,7 @@ def set_auto_num_workers(enable): If turned on, the num_parallel_workers in each op will be adjusted automatically, possibly overwriting the num_parallel_workers passed in by user or the default value (if user doesn't pass anything) set by ds.config.set_num_parallel_workers(). - For now, this function is only optimized for Yolo3 dataset with per_batch_map (running map in batch). + For now, this function is only optimized for YoloV3 dataset with per_batch_map (running map in batch). This feature aims to provide a baseline for optimized num_workers assignment for each op. Op whose num_parallel_workers is adjusted to a new value will be logged. diff --git a/mindspore/lite/minddata/CMakeLists.txt b/mindspore/lite/minddata/CMakeLists.txt index f2378c847b..4b32e8f1cc 100644 --- a/mindspore/lite/minddata/CMakeLists.txt +++ b/mindspore/lite/minddata/CMakeLists.txt @@ -192,9 +192,13 @@ if(BUILD_MINDDATA STREQUAL "full") ${MINDDATA_DIR}/kernels/image/lite_image_utils.cc ${MINDDATA_DIR}/kernels/image/center_crop_op.cc ${MINDDATA_DIR}/kernels/image/crop_op.cc + ${MINDDATA_DIR}/kernels/image/decode_op.cc ${MINDDATA_DIR}/kernels/image/normalize_op.cc + ${MINDDATA_DIR}/kernels/image/affine_op.cc ${MINDDATA_DIR}/kernels/image/resize_op.cc ${MINDDATA_DIR}/kernels/image/rotate_op.cc + ${MINDDATA_DIR}/kernels/image/random_affine_op.cc + ${MINDDATA_DIR}/kernels/image/math_utils.cc ${MINDDATA_DIR}/kernels/data/compose_op.cc ${MINDDATA_DIR}/kernels/data/duplicate_op.cc ${MINDDATA_DIR}/kernels/data/one_hot_op.cc @@ -350,7 +354,6 @@ elseif(BUILD_MINDDATA STREQUAL "lite") "${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc" "${MINDDATA_DIR}/kernels/image/image_utils.cc" "${MINDDATA_DIR}/kernels/image/invert_op.cc" - "${MINDDATA_DIR}/kernels/image/math_utils.cc" "${MINDDATA_DIR}/kernels/image/mixup_batch_op.cc" "${MINDDATA_DIR}/kernels/image/pad_op.cc" "${MINDDATA_DIR}/kernels/image/posterize_op.cc" diff --git 
a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index d63ea56b0c..74443eb24a 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -1,6 +1,7 @@ include(GoogleTest) SET(DE_UT_SRCS + affine_op_test.cc execute_test.cc album_op_test.cc arena_test.cc @@ -11,6 +12,7 @@ SET(DE_UT_SRCS btree_test.cc buddy_test.cc build_vocab_test.cc + c_api_affine_test.cc c_api_cache_test.cc c_api_dataset_album_test.cc c_api_dataset_cifar_test.cc diff --git a/tests/ut/cpp/dataset/affine_op_test.cc b/tests/ut/cpp/dataset/affine_op_test.cc new file mode 100644 index 0000000000..33fb5e9dec --- /dev/null +++ b/tests/ut/cpp/dataset/affine_op_test.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "common/cvop_common.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/affine_op.h" +#include "minddata/dataset/kernels/image/math_utils.h" +#include +#include +#include "lite_cv/lite_mat.h" +#include "lite_cv/image_process.h" + +using namespace mindspore::dataset; +using mindspore::dataset::InterpolationMode; + +class MindDataTestAffineOp : public UT::CVOP::CVOpCommon { + public: + MindDataTestAffineOp() : CVOpCommon() {} +}; + +// Helper function, consider moving this to helper class for UT +double Mse(cv::Mat img1, cv::Mat img2) { + // clone to get around open cv optimization + cv::Mat output1 = img1.clone(); + cv::Mat output2 = img2.clone(); + + // input check + if (output1.rows < 0 || output1.rows != output2.rows || output1.cols < 0 || output1.cols != output2.cols) { + return 10000.0; + } + return cv::norm(output1, output2, cv::NORM_L1); +} + +// helper function to generate corresponding affine matrix +std::vector GenerateMatrix(const std::shared_ptr &input, float_t degrees, + const std::vector &translation, float_t scale, + const std::vector &shear) { + float_t translation_x = translation[0]; + float_t translation_y = translation[1]; + DegreesToRadians(degrees, &degrees); + float_t shear_x = shear[0]; + float_t shear_y = shear[1]; + DegreesToRadians(shear_x, &shear_x); + DegreesToRadians(-1 * shear_y, &shear_y); + float_t cx = ((input->shape()[1] - 1) / 2.0); + float_t cy = ((input->shape()[0] - 1) / 2.0); + // Calculate RSS + std::vector matrix{ + static_cast(scale * cos(degrees + shear_y) / cos(shear_y)), + static_cast(scale * (-1 * cos(degrees + shear_y) * tan(shear_x) / cos(shear_y) - sin(degrees))), + 0, + static_cast(scale * sin(degrees + shear_y) / cos(shear_y)), + static_cast(scale * (-1 * sin(degrees + shear_y) * tan(shear_x) / cos(shear_y) + cos(degrees))), + 0}; + // Compute T * C * RSS * C^-1 + matrix[2] = (1 - matrix[0]) * cx - matrix[1] * cy + translation_x; + matrix[5] = (1 - matrix[4]) * cy -
matrix[3] * cx + translation_y; + return matrix; +} + +TEST_F(MindDataTestAffineOp, TestAffineLite) { + MS_LOG(INFO) << "Doing MindDataTestAffine-TestAffineLite."; + + // create input tensor and + float degree = 0.0; + std::vector translation = {0.0, 0.0}; + float scale = 0.0; + std::vector shear = {0.0, 0.0}; + + // Create affine object with default values + std::shared_ptr op(new AffineOp(degree, translation, scale, shear, InterpolationMode::kLinear)); + // output tensor + std::shared_ptr output_tensor; + + // output + LiteMat dst; + LiteMat lite_mat_rgb(input_tensor_->shape()[1], input_tensor_->shape()[0], input_tensor_->shape()[2], + const_cast(reinterpret_cast(input_tensor_->GetBuffer())), + LDataType::UINT8); + + std::vector matrix = GenerateMatrix(input_tensor_, degree, translation, scale, shear); + + int height = lite_mat_rgb.height_; + int width = lite_mat_rgb.width_; + std::vector dsize; + dsize.push_back(width); + dsize.push_back(height); + double M[6] = {}; + for (int i = 0; i < matrix.size(); i++) { + M[i] = static_cast(matrix[i]); + } + + EXPECT_TRUE(Affine(lite_mat_rgb, dst, M, dsize, UINT8_C3(0, 0, 0))); + Status s = op->Compute(input_tensor_, &output_tensor); + EXPECT_TRUE(s.IsOk()); + // output tensor is a cv tensor, we can compare mat values + cv::Mat lite_cv_out(dst.height_, dst.width_, CV_8UC3, dst.data_ptr_); + double mse = Mse(lite_cv_out, CVTensor(output_tensor).mat()); + MS_LOG(INFO) << "mse: " << std::to_string(mse) << std::endl; + EXPECT_LT(mse, 1); // predetermined magic number +} diff --git a/tests/ut/cpp/dataset/c_api_affine_test.cc b/tests/ut/cpp/dataset/c_api_affine_test.cc new file mode 100644 index 0000000000..ab28c0b205 --- /dev/null +++ b/tests/ut/cpp/dataset/c_api_affine_test.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/include/transforms.h" +#include "minddata/dataset/include/vision.h" + +using namespace mindspore::dataset; +using mindspore::dataset::InterpolationMode; +using mindspore::dataset::Tensor; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +TEST_F(MindDataTestPipeline, TestAffineAPI) { + MS_LOG(INFO) << "Doing MindDataTestAffine-TestAffineAPI."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 5)); + + // Create a Repeat operation on ds + int32_t repeat_num = 3; + ds = ds->Repeat(repeat_num); + + // Create the RandomCrop and Affine transform objects + + std::shared_ptr crop(new vision::RandomCrop({256, 256})); + std::shared_ptr affine( + new vision::Affine(0.0, {0.0, 0.0}, 0.0, {0.0, 0.0}, InterpolationMode::kLinear)); + + // Create a Map operation on ds + ds = ds->Map({crop, affine}); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it.
+ std::shared_ptr iter = ds->CreateIterator(); + + // Iterate the dataset and get each row + std::unordered_map row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + // auto image = row["image"]; + // MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + // EXPECT_EQ(row["image"].Shape()[0], 256); + } + + EXPECT_EQ(i, 15); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestAffineAPIFail) { + MS_LOG(INFO) << "Doing MindDataTestAffine-TestAffineAPI."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 5)); + + // Create a Repeat operation on ds + int32_t repeat_num = 3; + ds = ds->Repeat(repeat_num); + + // Create the RandomCrop and Affine transform objects (Affine is given translation {2.0, -1.0}) + + std::shared_ptr crop(new vision::RandomCrop({256, 256})); + std::shared_ptr affine( + new vision::Affine(0.0, {2.0, -1.0}, 0.0, {0.0, 0.0}, InterpolationMode::kLinear)); + + // Create a Map operation on ds + ds = ds->Map({crop, affine}); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +} +