!3850 Implementing RandomAffine C Op and a CPP UT for AutoContrast & Equalize

Merge pull request !3850 from islam_amin/randomaffine_op
pull/3850/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2cab58fdbc

@ -31,6 +31,7 @@
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
#include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_affine_op.h"
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
@ -115,6 +116,19 @@ PYBIND_REGISTER(ResizeWithBBoxOp, 1, ([](const py::module *m) {
py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation);
}));
// Python binding for RandomAffineOp. The class is exposed to Python under the name "RandomAffineOp", with TensorOp
// as its bound base and std::shared_ptr as holder type. The single bound constructor takes six keyword arguments;
// each one defaults to the corresponding RandomAffineOp static constant.
PYBIND_REGISTER(RandomAffineOp, 1, ([](const py::module *m) {
(void)py::class_<RandomAffineOp, TensorOp, std::shared_ptr<RandomAffineOp>>(
*m, "RandomAffineOp", "Tensor operation to apply random affine transformations on an image.")
.def(py::init<std::vector<float_t>, std::vector<float_t>, std::vector<float_t>,
std::vector<float_t>, InterpolationMode, std::vector<uint8_t>>(),
py::arg("degrees") = RandomAffineOp::kDegreesRange,
py::arg("translate_range") = RandomAffineOp::kTranslationPercentages,
py::arg("scale_range") = RandomAffineOp::kScaleRange,
py::arg("shear_ranges") = RandomAffineOp::kShearRanges,
py::arg("interpolation") = RandomAffineOp::kDefInterpolation,
py::arg("fill_value") = RandomAffineOp::kFillValue);
}));
PYBIND_REGISTER(
RandomResizeWithBBoxOp, 1, ([](const py::module *m) {
(void)py::class_<RandomResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomResizeWithBBoxOp>>(

@ -25,6 +25,7 @@
#include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/data/one_hot_op.h"
#include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_affine_op.h"
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
#include "minddata/dataset/kernels/image/random_crop_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
@ -136,6 +137,22 @@ std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float>
return op;
}
// Function to create RandomAffineOperation.
std::shared_ptr<RandomAffineOperation> RandomAffine(const std::vector<float_t> &degrees,
const std::vector<float_t> &translate_range,
const std::vector<float_t> &scale_range,
const std::vector<float_t> &shear_ranges,
InterpolationMode interpolation,
const std::vector<uint8_t> &fill_value) {
auto op = std::make_shared<RandomAffineOperation>(degrees, translate_range, scale_range, shear_ranges, interpolation,
fill_value);
// Input validation
if (!op->ValidateParams()) {
return nullptr;
}
return op;
}
// Function to create RandomCropOperation.
std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding,
bool pad_if_needed, std::vector<uint8_t> fill_value,
@ -452,6 +469,82 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
return tensor_op;
}
// RandomAffineOperation
// The constructor only stores copies of its arguments; no checking happens here. Size and ordering checks are
// performed separately by ValidateParams().
RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> &degrees,
const std::vector<float_t> &translate_range,
const std::vector<float_t> &scale_range,
const std::vector<float_t> &shear_ranges, InterpolationMode interpolation,
const std::vector<uint8_t> &fill_value)
: degrees_(degrees),
translate_range_(translate_range),
scale_range_(scale_range),
shear_ranges_(shear_ranges),
interpolation_(interpolation),
fill_value_(fill_value) {}
bool RandomAffineOperation::ValidateParams() {
// Degrees
if (degrees_.size() != 2) {
MS_LOG(ERROR) << "RandomAffine: degrees vector has incorrect size: degrees.size() = " << degrees_.size();
return false;
}
if (degrees_[0] > degrees_[1]) {
MS_LOG(ERROR) << "RandomAffine: minimum of degrees range is greater than maximum: min = " << degrees_[0]
<< ", max = " << degrees_[1];
return false;
}
// Translate
if (translate_range_.size() != 2) {
MS_LOG(ERROR) << "RandomAffine: translate_range vector has incorrect size: translate_range.size() = "
<< translate_range_.size();
return false;
}
if (translate_range_[0] > translate_range_[1]) {
MS_LOG(ERROR) << "RandomAffine: minimum of translate range is greater than maximum: min = " << translate_range_[0]
<< ", max = " << translate_range_[1];
return false;
}
// Scale
if (scale_range_.size() != 2) {
MS_LOG(ERROR) << "RandomAffine: scale_range vector has incorrect size: scale_range.size() = "
<< scale_range_.size();
return false;
}
if (scale_range_[0] > scale_range_[1]) {
MS_LOG(ERROR) << "RandomAffine: minimum of scale range is greater than maximum: min = " << scale_range_[0]
<< ", max = " << scale_range_[1];
return false;
}
// Shear
if (shear_ranges_.size() != 4) {
MS_LOG(ERROR) << "RandomAffine: shear_ranges vector has incorrect size: shear_ranges.size() = "
<< shear_ranges_.size();
return false;
}
if (shear_ranges_[0] > shear_ranges_[1]) {
MS_LOG(ERROR) << "RandomAffine: minimum of horizontal shear range is greater than maximum: min = "
<< shear_ranges_[0] << ", max = " << shear_ranges_[1];
return false;
}
if (shear_ranges_[2] > shear_ranges_[3]) {
MS_LOG(ERROR) << "RandomAffine: minimum of vertical shear range is greater than maximum: min = " << shear_ranges_[2]
<< ", max = " << scale_range_[3];
return false;
}
// Fill Value
if (fill_value_.size() != 3) {
MS_LOG(ERROR) << "RandomAffine: fill_value vector has incorrect size: fill_value.size() = " << fill_value_.size();
return false;
}
return true;
}
// Creates the runtime RandomAffineOp from the stored parameters and returns it through the TensorOp base pointer.
std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
  return std::make_shared<RandomAffineOp>(degrees_, translate_range_, scale_range_, shear_ranges_, interpolation_,
                                          fill_value_);
}
// RandomCropOperation
RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
std::vector<uint8_t> fill_value, BorderType padding_mode)

@ -55,6 +55,7 @@ class MixUpBatchOperation;
class NormalizeOperation;
class OneHotOperation;
class PadOperation;
class RandomAffineOperation;
class RandomColorAdjustOperation;
class RandomCropOperation;
class RandomHorizontalFlipOperation;
@ -134,6 +135,23 @@ std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes);
std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
BorderType padding_mode = BorderType::kConstant);
/// \brief Function to create a RandomAffine TensorOperation.
/// \notes Applies a Random Affine transformation on input image in RGB or Greyscale mode.
/// \param[in] degrees A float vector of size 2, representing the minimum and maximum rotation degrees.
/// \param[in] translate_range A float vector of size 2, representing the maximum translation percentages on the
///     x and y axes (default = {0.0, 0.0}).
/// \param[in] scale_range A float vector of size 2, representing the minimum and maximum scale
///     (default = {1.0, 1.0}).
/// \param[in] shear_ranges A float vector of size 4, laid out as
///     {min_x_shear, max_x_shear, min_y_shear, max_y_shear} in degrees (default = {0.0, 0.0, 0.0, 0.0}).
/// \param[in] interpolation An enum for the mode of interpolation (default = InterpolationMode::kNearestNeighbour).
/// \param[in] fill_value A uint8_t vector of size 3, representing the pixel intensity of the borders; it is used to
///     fill R, G, B channels respectively (default = {0, 0, 0}).
/// \return Shared pointer to the created operation, or nullptr when parameter validation fails.
std::shared_ptr<RandomAffineOperation> RandomAffine(
const std::vector<float_t> &degrees, const std::vector<float_t> &translate_range = {0.0, 0.0},
const std::vector<float_t> &scale_range = {1.0, 1.0}, const std::vector<float_t> &shear_ranges = {0.0, 0.0, 0.0, 0.0},
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
const std::vector<uint8_t> &fill_value = {0, 0, 0});
/// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image
/// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values
/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
@ -333,6 +351,29 @@ class PadOperation : public TensorOperation {
BorderType padding_mode_;
};
/// \brief Parameter holder for the RandomAffine transform. It stores the user-supplied ranges, checks them in
///     ValidateParams() and creates the runtime RandomAffineOp in Build().
class RandomAffineOperation : public TensorOperation {
public:
/// \brief Stores the given parameters; validation is done separately by ValidateParams().
RandomAffineOperation(const std::vector<float_t> &degrees, const std::vector<float_t> &translate_range = {0.0, 0.0},
const std::vector<float_t> &scale_range = {1.0, 1.0},
const std::vector<float_t> &shear_ranges = {0.0, 0.0, 0.0, 0.0},
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
const std::vector<uint8_t> &fill_value = {0, 0, 0});
~RandomAffineOperation() = default;
/// \brief Creates a RandomAffineOp from the stored parameters.
std::shared_ptr<TensorOp> Build() override;
/// \brief Checks vector sizes and range ordering; returns false (after logging) when a parameter is rejected.
bool ValidateParams() override;
private:
std::vector<float_t> degrees_; // min_degree, max_degree
std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage
std::vector<float_t> scale_range_; // min_scale, max_scale
std::vector<float_t> shear_ranges_; // min_x_shear, max_x_shear, min_y_shear, max_y_shear
InterpolationMode interpolation_; // interpolation mode handed to the op
std::vector<uint8_t> fill_value_; // border fill, one value per R, G, B channel
};
class RandomColorAdjustOperation : public TensorOperation {
public:
RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0},

@ -1,6 +1,7 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(kernels-image OBJECT
affine_op.cc
auto_contrast_op.cc
center_crop_op.cc
crop_op.cc
@ -10,9 +11,11 @@ add_library(kernels-image OBJECT
hwc_to_chw_op.cc
image_utils.cc
invert_op.cc
math_utils.cc
mixup_batch_op.cc
normalize_op.cc
pad_op.cc
random_affine_op.cc
random_color_adjust_op.cc
random_crop_decode_resize_op.cc
random_crop_and_resize_with_bbox_op.cc

@ -0,0 +1,99 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <random>
#include <utility>
#include <vector>
#include "minddata/dataset/kernels/image/affine_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/image/math_utils.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
// Default parameter values for AffineOp: no rotation, no translation, unit scale, no shear,
// nearest-neighbour interpolation and a zero fill value for each of three channels.
const InterpolationMode AffineOp::kDefInterpolation = InterpolationMode::kNearestNeighbour;
const float_t AffineOp::kDegrees = 0.0;
const std::vector<float_t> AffineOp::kTranslation = {0.0, 0.0};
const float_t AffineOp::kScale = 1.0;
const std::vector<float_t> AffineOp::kShear = {0.0, 0.0};
const std::vector<uint8_t> AffineOp::kFillValue = {0, 0, 0};
// Stores the transform parameters as given. Nothing is validated here; Compute() reads translation_[0..1],
// shear_[0..1] and fill_value_[0..2].
AffineOp::AffineOp(float_t degrees, const std::vector<float_t> &translation, float_t scale,
const std::vector<float_t> &shear, InterpolationMode interpolation,
const std::vector<uint8_t> &fill_value)
: degrees_(degrees),
translation_(translation),
scale_(scale),
shear_(shear),
interpolation_(interpolation),
fill_value_(fill_value) {}
// Applies the affine transform described by degrees_, translation_, scale_ and shear_ to the input image.
// \param[in] input Image tensor to transform.
// \param[out] output Receives a newly created tensor with the same shape and type as the input.
// \return Status::OK() on success. An unexpected-error status is returned when a parameter vector is too short to
//     be indexed, when a helper fails, or when OpenCV throws.
Status AffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  // These vectors are indexed directly below; reject malformed ones instead of reading out of bounds.
  if (translation_.size() < 2 || shear_.size() < 2 || fill_value_.size() < 3) {
    RETURN_STATUS_UNEXPECTED("Affine: translation and shear need 2 values each and fill_value needs 3 values.");
  }
  try {
    float_t translation_x = translation_[0];
    float_t translation_y = translation_[1];
    float_t degrees = 0.0;
    RETURN_IF_NOT_OK(DegreesToRadians(degrees_, &degrees));
    float_t shear_x = shear_[0];
    float_t shear_y = shear_[1];
    RETURN_IF_NOT_OK(DegreesToRadians(shear_x, &shear_x));
    // The y shear is negated here, which is why the matrix below uses (degrees + shear_y) for the (a - sy) terms.
    RETURN_IF_NOT_OK(DegreesToRadians(-1 * shear_y, &shear_y));
    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);

    // Apply Affine Transformation
    //       T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    //       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    //       RSS is rotation with scale and shear matrix
    //       RSS(a, s, (sx, sy)) =
    //       = R(a) * S(s) * SHy(sy) * SHx(sx)
    //       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    //         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    //         [ 0                    , 0                                        , 1 ]
    //
    // where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    // SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    //          [0, 1      ]              [-tan(s), 1]
    //
    // Thus, the affine matrix is M = T * C * RSS * C^-1
    float_t cx = ((input_cv->mat().cols - 1) / 2.0);
    float_t cy = ((input_cv->mat().rows - 1) / 2.0);
    // Calculate RSS (row-major 2x3; the translation column is filled in afterwards)
    std::vector<float_t> matrix{scale_ * cos(degrees + shear_y) / cos(shear_y),
                                scale_ * (-1 * cos(degrees + shear_y) * tan(shear_x) / cos(shear_y) - sin(degrees)),
                                0,
                                scale_ * sin(degrees + shear_y) / cos(shear_y),
                                scale_ * (-1 * sin(degrees + shear_y) * tan(shear_x) / cos(shear_y) + cos(degrees)),
                                0};
    // Compute T * C * RSS * C^-1
    matrix[2] = (1 - matrix[0]) * cx - matrix[1] * cy + translation_x;
    matrix[5] = (1 - matrix[4]) * cy - matrix[3] * cx + translation_y;
    cv::Mat affine_mat(matrix);
    affine_mat = affine_mat.reshape(1, {2, 3});

    std::shared_ptr<CVTensor> output_cv;
    RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
    RETURN_UNEXPECTED_IF_NULL(output_cv);
    cv::warpAffine(input_cv->mat(), output_cv->mat(), affine_mat, input_cv->mat().size(),
                   GetCVInterpolationMode(interpolation_), cv::BORDER_CONSTANT,
                   cv::Scalar(fill_value_[0], fill_value_[1], fill_value_[2]));
    (*output) = std::static_pointer_cast<Tensor>(output_cv);
  } catch (const cv::Exception &e) {
    // Surface OpenCV failures as a Status instead of letting the exception escape the op.
    const std::string err_message = std::string("Affine: OpenCV error: ") + e.what();
    RETURN_STATUS_UNEXPECTED(err_message);
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AFFINE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AFFINE_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief TensorOp that applies an affine transform (rotation, translation, scale and shear) with fixed
///     parameters to an image. The parameter members are protected so that subclasses can overwrite them.
class AffineOp : public TensorOp {
public:
/// Default values
static const float_t kDegrees;
static const std::vector<float_t> kTranslation;
static const float_t kScale;
static const std::vector<float_t> kShear;
static const InterpolationMode kDefInterpolation;
static const std::vector<uint8_t> kFillValue;
/// Constructor
public:
/// \brief Stores the transform parameters; omitted arguments take the static defaults above.
explicit AffineOp(float_t degrees, const std::vector<float_t> &translation = kTranslation, float_t scale = kScale,
const std::vector<float_t> &shear = kShear, InterpolationMode interpolation = kDefInterpolation,
const std::vector<uint8_t> &fill_value = kFillValue);
~AffineOp() override = default;
/// \brief Returns the op name, "AffineOp".
std::string Name() const override { return kAffineOp; }
/// \brief Warps the input image with the stored parameters and writes the result to output.
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
/// Member variables
private:
// Op name returned by Name(). NOTE(review): this is a per-instance member, not a static constant, despite the
// k-prefix.
std::string kAffineOp = "AffineOp";
protected:
float_t degrees_; // rotation, converted from degrees to radians in Compute()
std::vector<float_t> translation_; // translation_x and translation_y
float_t scale_; // scale factor applied to the rotation/shear matrix
std::vector<float_t> shear_; // shear_x and shear_y
InterpolationMode interpolation_; // interpolation mode passed on to OpenCV
std::vector<uint8_t> fill_value_; // border fill, three values are read
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AFFINE_OP_H_

@ -21,6 +21,7 @@
#include <utility>
#include <opencv2/imgcodecs.hpp>
#include "utils/ms_utils.h"
#include "minddata/dataset/kernels/image/math_utils.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/tensor.h"
@ -631,36 +632,9 @@ Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
hist.col(0).copyTo(hist_vec);
// Ignore values in ignore
for (const auto &item : ignore) hist_vec[item] = 0;
int32_t n = std::accumulate(hist_vec.begin(), hist_vec.end(), 0);
// Find pixel values that are in the low cutoff and high cutoff.
int32_t cut = static_cast<int32_t>((cutoff / 100.0) * n);
if (cut != 0) {
for (int32_t lo = 0; lo < 256 && cut > 0; lo++) {
if (cut > hist_vec[lo]) {
cut -= hist_vec[lo];
hist_vec[lo] = 0;
} else {
hist_vec[lo] -= cut;
cut = 0;
}
}
cut = static_cast<int32_t>((cutoff / 100.0) * n);
for (int32_t hi = 255; hi >= 0 && cut > 0; hi--) {
if (cut > hist_vec[hi]) {
cut -= hist_vec[hi];
hist_vec[hi] = 0;
} else {
hist_vec[hi] -= cut;
cut = 0;
}
}
}
int32_t lo = 0;
int32_t hi = 255;
for (; lo < 256 && !hist_vec[lo]; lo++) {
}
for (; hi >= 0 && !hist_vec[hi]; hi--) {
}
int32_t lo = 0;
RETURN_IF_NOT_OK(ComputeUpperAndLowerPercentiles(&hist_vec, cutoff, cutoff, &hi, &lo));
if (hi <= lo) {
for (int32_t i = 0; i < 256; i++) {
table.push_back(i);
@ -685,7 +659,6 @@ Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv));
(*output) = std::static_pointer_cast<Tensor>(output_cv);
(*output) = std::static_pointer_cast<Tensor>(output_cv);
(*output)->Reshape(input->shape());
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("Error in auto contrast");

@ -0,0 +1,84 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/math_utils.h"
#include <opencv2/imgproc/types_c.h>
#include <algorithm>
#include <numeric>
#include <string>
namespace mindspore {
namespace dataset {
// Trims low_p percent of the histogram mass from the low end and hi_p percent from the high end, then reports the
// lowest and highest bins that still hold counts.
// \param[in,out] hist Histogram bin counts; the trimmed counts are written back into it.
// \param[in] hi_p Percentage of the total count removed from the high end.
// \param[in] low_p Percentage of the total count removed from the low end.
// \param[out] hi Index of the highest non-empty bin after trimming (-1 when no bin is left).
// \param[out] lo Index of the lowest non-empty bin after trimming; the search stops once it reaches the last bin.
// \return Status::OK() on success, an unexpected-error status on a null argument or a caught exception.
Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p, int32_t low_p, int32_t *hi,
                                       int32_t *lo) {
  if (hist == nullptr || hi == nullptr || lo == nullptr) {
    RETURN_STATUS_UNEXPECTED("Error in ComputeUpperAndLowerPercentiles: null pointer argument.");
  }
  try {
    // Signed bin count keeps the loop comparisons below free of signed/unsigned mixing.
    const int32_t num_bins = static_cast<int32_t>(hist->size());
    int32_t n = std::accumulate(hist->begin(), hist->end(), 0);
    int32_t cut = static_cast<int32_t>((low_p / 100.0) * n);
    // Stop at the last bin: a cut larger than the histogram total must not read past the end of the vector.
    for (int32_t lb = 0; lb < num_bins && cut > 0; lb++) {
      if (cut > (*hist)[lb]) {
        cut -= (*hist)[lb];
        (*hist)[lb] = 0;
      } else {
        (*hist)[lb] -= cut;
        cut = 0;
      }
    }
    cut = static_cast<int32_t>((hi_p / 100.0) * n);
    for (int32_t ub = num_bins - 1; ub >= 0 && cut > 0; ub--) {
      if (cut > (*hist)[ub]) {
        cut -= (*hist)[ub];
        (*hist)[ub] = 0;
      } else {
        (*hist)[ub] -= cut;
        cut = 0;
      }
    }
    *lo = 0;
    *hi = num_bins - 1;
    // Skip the emptied bins at both ends.
    for (; (*lo) < (*hi) && !(*hist)[*lo]; (*lo)++) {
    }
    for (; (*hi) >= 0 && !(*hist)[*hi]; (*hi)--) {
    }
  } catch (const std::exception &e) {
    const char *err_msg = e.what();
    std::string err_message = "Error in ComputeUpperAndLowerPercentiles: ";
    err_message += err_msg;
    RETURN_STATUS_UNEXPECTED(err_message);
  }
  return Status::OK();
}
// Converts an angle from degrees to radians and stores it in *radians_target. Always returns Status::OK().
Status DegreesToRadians(float_t degrees, float_t *radians_target) {
  const auto angle_in_radians = CV_PI * degrees / 180.0;
  *radians_target = angle_in_radians;
  return Status::OK();
}
// Draws one value from a uniform real distribution constructed with bounds (a, b), using the supplied engine, and
// stores it in *result. A std::exception raised while sampling is turned into an unexpected-error Status.
Status GenerateRealNumber(float_t a, float_t b, std::mt19937 *rnd, float_t *result) {
  try {
    std::uniform_real_distribution<float_t> uniform(a, b);
    *result = uniform(*rnd);
  } catch (const std::exception &e) {
    const std::string failure = std::string("Error in GenerateRealNumber: ") + e.what();
    RETURN_STATUS_UNEXPECTED(failure);
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MATH_UTILS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MATH_UTILS_H_
#include <memory>
#include <random>
#include <vector>
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief Returns lower and upper pth percentiles of the input histogram.
/// \param[in,out] hist: Input histogram; its bin counts are modified in place while the percentiles are computed
/// \param[in] hi_p: Right side percentile
/// \param[in] low_p: Left side percentile
/// \param[out] hi: Value at high end percentile
/// \param[out] lo: Value at low end percentile
/// \return Status::OK() on success, an error status if an exception is caught during the computation
Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p, int32_t low_p, int32_t *hi,
int32_t *lo);
/// \brief Converts degrees input to radians.
/// \param[in] degrees: Input degrees
/// \param[out] radians_target: Radians output
/// \return Status::OK()
Status DegreesToRadians(float_t degrees, float_t *radians_target);
/// \brief Generates a random real number in [a,b).
/// \param[in] a: Start of range
/// \param[in] b: End of range
/// \param[in] rnd: Random number engine (std::mt19937) to draw from; its state advances with each call
/// \param[out] result: Random number in range [a,b)
/// \return Status::OK() on success, an error status if an exception is caught while sampling
Status GenerateRealNumber(float_t a, float_t b, std::mt19937 *rnd, float_t *result);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MATH_UTILS_H_

@ -0,0 +1,77 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <random>
#include <utility>
#include <vector>
#include "minddata/dataset/kernels/image/random_affine_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/image/math_utils.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
// Default parameter values for RandomAffineOp. With these defaults every sampled range collapses to a single
// value: zero rotation, zero translation, unit scale and zero shear.
const std::vector<float_t> RandomAffineOp::kDegreesRange = {0.0, 0.0};
const std::vector<float_t> RandomAffineOp::kTranslationPercentages = {0.0, 0.0};
const std::vector<float_t> RandomAffineOp::kScaleRange = {1.0, 1.0};
const std::vector<float_t> RandomAffineOp::kShearRanges = {0.0, 0.0, 0.0, 0.0};
const InterpolationMode RandomAffineOp::kDefInterpolation = InterpolationMode::kNearestNeighbour;
const std::vector<uint8_t> RandomAffineOp::kFillValue = {0, 0, 0};
// Stores the sampling ranges and seeds the random engine from GetSeed().
// The base AffineOp is created with placeholder parameters (0 degrees plus its defaults); Compute() overwrites
// them with freshly sampled values on every call. Interpolation mode and fill value are fixed here.
RandomAffineOp::RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t> translate_range,
                               std::vector<float_t> scale_range, std::vector<float_t> shear_ranges,
                               InterpolationMode interpolation, std::vector<uint8_t> fill_value)
    : AffineOp(0.0),
      // The parameters are taken by value, so move them into place instead of copying a second time.
      degrees_range_(std::move(degrees)),
      translate_range_(std::move(translate_range)),
      scale_range_(std::move(scale_range)),
      shear_ranges_(std::move(shear_ranges)) {
  interpolation_ = interpolation;
  fill_value_ = std::move(fill_value);
  rnd_.seed(GetSeed());
}
// Samples a rotation, translation, scale and shear from the configured ranges, writes them into the inherited
// AffineOp parameters and delegates the actual warp to AffineOp::Compute().
// \param[in] input Image tensor; dimension 0 is read as height and dimension 1 as width.
// \param[out] output Receives the transformed image.
// \return Status of the sampling helpers or of AffineOp::Compute(); an unexpected-error status when the input has
//     fewer than two dimensions or a range vector is too short to be indexed.
Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  if (input->Rank() < 2) {
    RETURN_STATUS_UNEXPECTED("RandomAffine: input tensor must have at least 2 dimensions (height, width).");
  }
  // The range vectors are indexed directly below; reject malformed ones instead of reading out of bounds.
  if (degrees_range_.size() < 2 || translate_range_.size() < 2 || scale_range_.size() < 2 ||
      shear_ranges_.size() < 4) {
    RETURN_STATUS_UNEXPECTED("RandomAffine: degrees, translate and scale ranges need 2 values, shear ranges need 4.");
  }
  dsize_t height = input->shape()[0];
  dsize_t width = input->shape()[1];
  // Translation along x is a fraction of the width and translation along y a fraction of the height.
  float_t max_dx = translate_range_[0] * width;
  float_t max_dy = translate_range_[1] * height;
  // The sampling order below is kept fixed so that a given seed keeps producing the same sequence of draws.
  float_t degrees = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(degrees_range_[0], degrees_range_[1], &rnd_, &degrees));
  float_t translation_x = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(-1 * max_dx, max_dx, &rnd_, &translation_x));
  float_t translation_y = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(-1 * max_dy, max_dy, &rnd_, &translation_y));
  float_t scale = 1.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(scale_range_[0], scale_range_[1], &rnd_, &scale));
  float_t shear_x = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(shear_ranges_[0], shear_ranges_[1], &rnd_, &shear_x));
  float_t shear_y = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(shear_ranges_[2], shear_ranges_[3], &rnd_, &shear_y));
  // assign to base class variables
  degrees_ = degrees;
  scale_ = scale;
  translation_[0] = translation_x;
  translation_[1] = translation_y;
  shear_[0] = shear_x;
  shear_[1] = shear_y;
  return AffineOp::Compute(input, output);
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,64 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_AFFINE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_AFFINE_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/image/affine_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief AffineOp whose rotation, translation, scale and shear are re-sampled from configured ranges on every
///     Compute() call.
class RandomAffineOp : public AffineOp {
public:
/// Default values, also used by python_bindings.cc
static const std::vector<float_t> kDegreesRange;
static const std::vector<float_t> kTranslationPercentages;
static const std::vector<float_t> kScaleRange;
static const std::vector<float_t> kShearRanges;
static const InterpolationMode kDefInterpolation;
static const std::vector<uint8_t> kFillValue;
/// \brief Stores the sampling ranges and seeds the random engine; omitted arguments take the defaults above.
explicit RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t> translate_range = kTranslationPercentages,
std::vector<float_t> scale_range = kScaleRange,
std::vector<float_t> shear_ranges = kShearRanges,
InterpolationMode interpolation = kDefInterpolation,
std::vector<uint8_t> fill_value = kFillValue);
~RandomAffineOp() override = default;
/// \brief Returns the op name, "RandomAffineOp".
std::string Name() const override { return kRandomAffineOp; }
/// \brief Samples new transform parameters and applies them through AffineOp::Compute().
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
private:
std::string kRandomAffineOp = "RandomAffineOp"; // op name returned by Name()
std::vector<float_t> degrees_range_; // min_degree, max_degree
std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage
std::vector<float_t> scale_range_; // min_scale, max_scale
std::vector<float_t> shear_ranges_; // min_x_shear, max_x_shear, min_y_shear, max_y_shear
std::mt19937 rnd_; // random number engine, seeded in the constructor
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_AFFINE_OP_H_

@ -47,7 +47,8 @@ from .utils import Inter, Border
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, FLOAT_MAX_INTEGER
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
FLOAT_MAX_INTEGER
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
@ -170,6 +171,95 @@ class Normalize(cde.NormalizeOp):
super().__init__(*mean, *std)
class RandomAffine(cde.RandomAffineOp):
    """
    Apply Random affine transformation to the input image.

    Args:
        degrees (int or float or sequence): Range of the rotation degrees.
            If degrees is a number, the range will be (-degrees, degrees).
            If degrees is a sequence, it should be (min, max).
        translate (sequence, optional): Sequence (tx, ty) of maximum translation in
            x(horizontal) and y(vertical) directions (default=None).
            The horizontal and vertical shift is selected randomly from the range:
            (-tx*width, tx*width) and (-ty*height, ty*height), respectively.
            If None, no translation is applied.
        scale (sequence, optional): Scaling factor interval (default=None, original scale is used).
        shear (int or float or sequence, optional): Range of shear factor (default=None).
            If a number 'shear', then a shear parallel to the x axis in the range of (-shear, +shear) is applied.
            If a tuple or list of size 2, then a shear parallel to the x axis in the range of (shear[0], shear[1])
            is applied.
            If a tuple or list of size 4, then a shear parallel to x axis in the range of (shear[0], shear[1])
            and a shear parallel to y axis in the range of (shear[2], shear[3]) is applied.
            If None, no shear is applied.
        resample (Inter mode, optional): An optional resampling filter (default=Inter.NEAREST).
            It can be any of [Inter.BILINEAR, Inter.NEAREST, Inter.BICUBIC].

            - Inter.BILINEAR, means resample method is bilinear interpolation.

            - Inter.NEAREST, means resample method is nearest-neighbor interpolation.

            - Inter.BICUBIC, means resample method is bicubic interpolation.

        fill_value (tuple or int, optional): Optional fill_value to fill the area outside the transform
            in the output image (default=0). A single number is used for all three channels.

    Raises:
        ValueError: If degrees is negative.
        ValueError: If translation value is not between 0 and 1.
        ValueError: If scale is not positive.
        ValueError: If shear is a number but is not positive.
        TypeError: If degrees is not a number or a list or a tuple.
            If degrees is a list or tuple, its length is not 2.
        TypeError: If translate is specified but is not list or a tuple of length 2.
        TypeError: If scale is not a list or tuple of length 2.
        TypeError: If shear is not a list or tuple of length 2 or 4.

    Examples:
        >>> c_transform.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1))
    """

    @check_random_affine
    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0):
        # Normalize shear to the 4-element layout (min_x, max_x, min_y, max_y) expected by the C++ op.
        if shear is not None:
            if isinstance(shear, numbers.Number):
                shear = (-1 * shear, shear, 0., 0.)
            else:
                if len(shear) == 2:
                    shear = [shear[0], shear[1], 0., 0.]
                elif len(shear) == 4:
                    shear = list(shear)

        # A single number means a symmetric rotation range around zero.
        if isinstance(degrees, numbers.Number):
            degrees = (-1 * degrees, degrees)

        # A single number is replicated for the three channels.
        if isinstance(fill_value, numbers.Number):
            fill_value = (fill_value, fill_value, fill_value)

        # translation
        if translate is None:
            translate = (0.0, 0.0)

        # scale
        if scale is None:
            scale = (1.0, 1.0)

        # shear
        if shear is None:
            shear = (0.0, 0.0, 0.0, 0.0)

        self.degrees = degrees
        self.translate = translate
        self.scale_ = scale
        self.shear = shear
        self.resample = DE_C_INTER_MODE[resample]
        self.fill_value = fill_value

        super().__init__(degrees, translate, scale, shear, DE_C_INTER_MODE[resample], fill_value)
class RandomCrop(cde.RandomCropOp):
"""
Crop the input image at a random location.

@ -4,6 +4,7 @@ SET(DE_UT_SRCS
common/common.cc
common/cvop_common.cc
common/bboxop_common.cc
auto_contrast_op_test.cc
batch_op_test.cc
bit_functions_test.cc
storage_container_test.cc
@ -22,6 +23,7 @@ SET(DE_UT_SRCS
cut_out_op_test.cc
datatype_test.cc
decode_op_test.cc
equalize_op_test.cc
execution_tree_test.cc
global_context_test.cc
main_test.cc
@ -36,6 +38,7 @@ SET(DE_UT_SRCS
path_test.cc
project_op_test.cc
queue_test.cc
random_affine_op_test.cc
random_crop_op_test.cc
random_crop_with_bbox_op_test.cc
random_crop_decode_resize_op_test.cc

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/auto_contrast_op.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
// Test fixture for AutoContrastOp; everything it provides is inherited from CVOpCommon.
class MindDataTestAutoContrastOp : public UT::CVOP::CVOpCommon {
 public:
  MindDataTestAutoContrastOp() = default;
};
// Runs AutoContrastOp on the fixture's input_tensor_ and hands the result to
// CheckImageShapeAndData for verification against the kAutoContrast case.
TEST_F(MindDataTestAutoContrastOp, TestOp1) {
  MS_LOG(INFO) << "Doing testAutoContrastOp.";
  std::shared_ptr<Tensor> output_tensor;
  std::unique_ptr<AutoContrastOp> op(new AutoContrastOp(1.0, {0, 255}));
  EXPECT_TRUE(op->OneToOne());
  Status s = op->Compute(input_tensor_, &output_tensor);
  // Fatal assertions: if Compute failed or produced nothing, there is no
  // output to verify, so stop here instead of passing a null tensor along.
  ASSERT_TRUE(s.IsOk());
  ASSERT_NE(output_tensor, nullptr);
  CheckImageShapeAndData(output_tensor, kAutoContrast);
}

@ -521,6 +521,119 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
iter->Stop();
}
// Builds an ImageFolder -> Repeat -> Map(RandomAffine) -> Batch pipeline with
// explicit (non-default) affine parameters and checks the number of rows produced.
TEST_F(MindDataTestPipeline, TestRandomAffineSuccess1) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAffineSuccess1 with non-default params.";

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
  // Every pointer checked below is dereferenced right afterwards, so the
  // checks are fatal (ASSERT) rather than non-fatal (EXPECT).
  ASSERT_NE(ds, nullptr);

  // Create a Repeat operation on ds
  int32_t repeat_num = 2;
  ds = ds->Repeat(repeat_num);
  ASSERT_NE(ds, nullptr);

  // Create objects for the tensor ops
  std::shared_ptr<TensorOperation> affine =
    vision::RandomAffine({30.0, 30.0}, {0.0, 0.0}, {2.0, 2.0}, {10.0, 10.0, 20.0, 20.0});
  ASSERT_NE(affine, nullptr);

  // Create a Map operation on ds
  ds = ds->Map({affine});
  ASSERT_NE(ds, nullptr);

  // Create a Batch operation on ds
  int32_t batch_size = 1;
  ds = ds->Batch(batch_size);
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image->shape();
    iter->GetNextRow(&row);
  }

  // 10 sampled images repeated twice with batch size 1.
  EXPECT_EQ(i, 20);

  // Manually terminate the pipeline
  iter->Stop();
}
// Same pipeline as the non-default variant, but RandomAffine is created with
// only the degrees argument so the remaining parameters take their defaults.
TEST_F(MindDataTestPipeline, TestRandomAffineSuccess2) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAffineSuccess2 with default params.";

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
  // Every pointer checked below is dereferenced right afterwards, so the
  // checks are fatal (ASSERT) rather than non-fatal (EXPECT).
  ASSERT_NE(ds, nullptr);

  // Create a Repeat operation on ds
  int32_t repeat_num = 2;
  ds = ds->Repeat(repeat_num);
  ASSERT_NE(ds, nullptr);

  // Create objects for the tensor ops
  std::shared_ptr<TensorOperation> affine = vision::RandomAffine({0.0, 0.0});
  ASSERT_NE(affine, nullptr);

  // Create a Map operation on ds
  ds = ds->Map({affine});
  ASSERT_NE(ds, nullptr);

  // Create a Batch operation on ds
  int32_t batch_size = 1;
  ds = ds->Batch(batch_size);
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image->shape();
    iter->GetNextRow(&row);
  }

  // 10 sampled images repeated twice with batch size 1.
  EXPECT_EQ(i, 20);

  // Manually terminate the pipeline
  iter->Stop();
}
// Each call below passes an argument list the factory is expected to reject,
// which it signals by returning a null operation.
TEST_F(MindDataTestPipeline, TestRandomAffineFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAffineFail with invalid params.";

  // Empty list passed as the second argument
  const std::shared_ptr<TensorOperation> empty_translate = vision::RandomAffine({0.0, 0.0}, {});
  EXPECT_EQ(empty_translate, nullptr);

  // Invalid number of values for translate
  const std::shared_ptr<TensorOperation> long_translate = vision::RandomAffine({0.0, 0.0}, {1, 1, 1, 1});
  EXPECT_EQ(long_translate, nullptr);

  // Invalid number of values for shear
  const std::shared_ptr<TensorOperation> short_shear =
    vision::RandomAffine({30.0, 30.0}, {0.0, 0.0}, {2.0, 2.0}, {10.0, 10.0});
  EXPECT_EQ(short_shear, nullptr);
}
TEST_F(MindDataTestPipeline, TestRandomRotation) {
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";

@ -130,6 +130,18 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te
expect_image_path = dir_path + "imagefolder/apple_expect_changemode.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_changemode.jpg";
break;
case kRandomAffine:
expect_image_path = dir_path + "imagefolder/apple_expect_randomaffine.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_randomaffine.jpg";
break;
case kAutoContrast:
expect_image_path = dir_path + "imagefolder/apple_expect_autocontrast.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_autocontrast.jpg";
break;
case kEqualize:
expect_image_path = dir_path + "imagefolder/apple_expect_equalize.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_equalize.jpg";
break;
default:
MS_LOG(INFO) << "Not pass verification! Operation type does not exists.";
EXPECT_EQ(0, 1);

@ -37,7 +37,10 @@ class CVOpCommon : public Common {
kChannelSwap,
kChangeMode,
kTemplate,
kCrop
kCrop,
kRandomAffine,
kAutoContrast,
kEqualize
};
CVOpCommon();

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/equalize_op.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
// Test fixture for EqualizeOp; everything it provides is inherited from CVOpCommon.
class MindDataTestEqualizeOp : public UT::CVOP::CVOpCommon {
 public:
  MindDataTestEqualizeOp() = default;
};
// Runs EqualizeOp on the fixture's input_tensor_ and hands the result to
// CheckImageShapeAndData for verification against the kEqualize case.
TEST_F(MindDataTestEqualizeOp, TestOp1) {
  MS_LOG(INFO) << "Doing testEqualizeOp.";
  std::shared_ptr<Tensor> output_tensor;
  std::unique_ptr<EqualizeOp> op(new EqualizeOp());
  EXPECT_TRUE(op->OneToOne());
  Status s = op->Compute(input_tensor_, &output_tensor);
  // Fatal assertions: if Compute failed or produced nothing, there is no
  // output to verify, so stop here instead of passing a null tensor along.
  ASSERT_TRUE(s.IsOk());
  ASSERT_NE(output_tensor, nullptr);
  CheckImageShapeAndData(output_tensor, kEqualize);
}

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/random_affine_op.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
// Test fixture for RandomAffineOp; everything it provides is inherited from CVOpCommon.
class MindDataTestRandomAffineOp : public UT::CVOP::CVOpCommon {
 public:
  MindDataTestRandomAffineOp() = default;
};
// Runs RandomAffineOp on the fixture's input_tensor_ and hands the result to
// CheckImageShapeAndData for verification against the kRandomAffine case.
// Each range is given identical lower and upper bounds.
TEST_F(MindDataTestRandomAffineOp, TestOp1) {
  MS_LOG(INFO) << "Doing testRandomAffineOp.";
  std::shared_ptr<Tensor> output_tensor;
  std::unique_ptr<RandomAffineOp> op(new RandomAffineOp({30.0, 30.0}, {0.0, 0.0}, {2.0, 2.0}, {10.0, 10.0, 20.0, 20.0},
                                                        InterpolationMode::kNearestNeighbour, {255, 0, 0}));
  EXPECT_TRUE(op->OneToOne());
  Status s = op->Compute(input_tensor_, &output_tensor);
  // Fatal assertions: if Compute failed or produced nothing, there is no
  // output to verify, so stop here instead of passing a null tensor along.
  ASSERT_TRUE(s.IsOk());
  ASSERT_NE(output_tensor, nullptr);
  CheckImageShapeAndData(output_tensor, kRandomAffine);
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 565 KiB

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save