|
|
@ -219,8 +219,8 @@ std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Function to create RandomPosterizeOperation.
|
|
|
|
// Function to create RandomPosterizeOperation.
|
|
|
|
std::shared_ptr<RandomPosterizeOperation> RandomPosterize(uint8_t min_bit, uint8_t max_bit) {
|
|
|
|
std::shared_ptr<RandomPosterizeOperation> RandomPosterize(const std::vector<uint8_t> &bit_range) {
|
|
|
|
auto op = std::make_shared<RandomPosterizeOperation>(min_bit, max_bit);
|
|
|
|
auto op = std::make_shared<RandomPosterizeOperation>(bit_range);
|
|
|
|
// Input validation
|
|
|
|
// Input validation
|
|
|
|
if (!op->ValidateParams()) {
|
|
|
|
if (!op->ValidateParams()) {
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
@ -383,7 +383,7 @@ CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format,
|
|
|
|
|
|
|
|
|
|
|
|
bool CutMixBatchOperation::ValidateParams() {
|
|
|
|
bool CutMixBatchOperation::ValidateParams() {
|
|
|
|
if (alpha_ <= 0) {
|
|
|
|
if (alpha_ <= 0) {
|
|
|
|
MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative.";
|
|
|
|
MS_LOG(ERROR) << "CutMixBatch: alpha must be a positive floating value however it is: " << alpha_;
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (prob_ < 0 || prob_ > 1) {
|
|
|
|
if (prob_ < 0 || prob_ > 1) {
|
|
|
@ -616,7 +616,7 @@ RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> °rees
|
|
|
|
bool RandomAffineOperation::ValidateParams() {
|
|
|
|
bool RandomAffineOperation::ValidateParams() {
|
|
|
|
// Degrees
|
|
|
|
// Degrees
|
|
|
|
if (degrees_.size() != 2) {
|
|
|
|
if (degrees_.size() != 2) {
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: degrees vector has incorrect size: degrees.size() = " << degrees_.size();
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: degrees expecting size 2, got: degrees.size() = " << degrees_.size();
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (degrees_[0] > degrees_[1]) {
|
|
|
|
if (degrees_[0] > degrees_[1]) {
|
|
|
@ -625,16 +625,43 @@ bool RandomAffineOperation::ValidateParams() {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Translate
|
|
|
|
// Translate
|
|
|
|
if (translate_range_.size() != 2) {
|
|
|
|
if (translate_range_.size() != 2 && translate_range_.size() != 4) {
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: translate_range vector has incorrect size: translate_range.size() = "
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: translate_range expecting size 2 or 4, got: translate_range.size() = "
|
|
|
|
<< translate_range_.size();
|
|
|
|
<< translate_range_.size();
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (translate_range_[0] > translate_range_[1]) {
|
|
|
|
if (translate_range_[0] > translate_range_[1]) {
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: minimum of translate range is greater than maximum: min = " << translate_range_[0]
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: minimum of translate range on x is greater than maximum: min = "
|
|
|
|
<< ", max = " << translate_range_[1];
|
|
|
|
<< translate_range_[0] << ", max = " << translate_range_[1];
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (translate_range_[0] < -1 || translate_range_[0] > 1) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: minimum of translate range on x is out of range of [-1, 1], value = "
|
|
|
|
|
|
|
|
<< translate_range_[0];
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (translate_range_[1] < -1 || translate_range_[1] > 1) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: maximum of translate range on x is out of range of [-1, 1], value = "
|
|
|
|
|
|
|
|
<< translate_range_[1];
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (translate_range_.size() == 4) {
|
|
|
|
|
|
|
|
if (translate_range_[2] > translate_range_[3]) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: minimum of translate range on y is greater than maximum: min = "
|
|
|
|
|
|
|
|
<< translate_range_[2] << ", max = " << translate_range_[3];
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (translate_range_[2] < -1 || translate_range_[2] > 1) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: minimum of translate range on y is out of range of [-1, 1], value = "
|
|
|
|
|
|
|
|
<< translate_range_[2];
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (translate_range_[3] < -1 || translate_range_[3] > 1) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: maximum of translate range on y is out of range of [-1, 1], value = "
|
|
|
|
|
|
|
|
<< translate_range_[3];
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
// Scale
|
|
|
|
// Scale
|
|
|
|
if (scale_range_.size() != 2) {
|
|
|
|
if (scale_range_.size() != 2) {
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: scale_range vector has incorrect size: scale_range.size() = "
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: scale_range vector has incorrect size: scale_range.size() = "
|
|
|
@ -647,8 +674,8 @@ bool RandomAffineOperation::ValidateParams() {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Shear
|
|
|
|
// Shear
|
|
|
|
if (shear_ranges_.size() != 4) {
|
|
|
|
if (shear_ranges_.size() != 2 && shear_ranges_.size() != 4) {
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: shear_ranges vector has incorrect size: shear_ranges.size() = "
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: shear_ranges expecting size 2 or 4, got: shear_ranges.size() = "
|
|
|
|
<< shear_ranges_.size();
|
|
|
|
<< shear_ranges_.size();
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -657,7 +684,7 @@ bool RandomAffineOperation::ValidateParams() {
|
|
|
|
<< shear_ranges_[0] << ", max = " << shear_ranges_[1];
|
|
|
|
<< shear_ranges_[0] << ", max = " << shear_ranges_[1];
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (shear_ranges_[2] > shear_ranges_[3]) {
|
|
|
|
if (shear_ranges_.size() == 4 && shear_ranges_[2] > shear_ranges_[3]) {
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: minimum of vertical shear range is greater than maximum: min = " << shear_ranges_[2]
|
|
|
|
MS_LOG(ERROR) << "RandomAffine: minimum of vertical shear range is greater than maximum: min = " << shear_ranges_[2]
|
|
|
|
<< ", max = " << scale_range_[3];
|
|
|
|
<< ", max = " << scale_range_[3];
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
@ -671,6 +698,12 @@ bool RandomAffineOperation::ValidateParams() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
|
|
|
|
std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
|
|
|
|
|
|
|
|
if (shear_ranges_.size() == 2) {
|
|
|
|
|
|
|
|
shear_ranges_.resize(4);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (translate_range_.size() == 2) {
|
|
|
|
|
|
|
|
translate_range_.resize(4);
|
|
|
|
|
|
|
|
}
|
|
|
|
auto tensor_op = std::make_shared<RandomAffineOp>(degrees_, translate_range_, scale_range_, shear_ranges_,
|
|
|
|
auto tensor_op = std::make_shared<RandomAffineOp>(degrees_, translate_range_, scale_range_, shear_ranges_,
|
|
|
|
interpolation_, fill_value_);
|
|
|
|
interpolation_, fill_value_);
|
|
|
|
return tensor_op;
|
|
|
|
return tensor_op;
|
|
|
@ -737,27 +770,31 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// RandomPosterizeOperation
|
|
|
|
// RandomPosterizeOperation
|
|
|
|
RandomPosterizeOperation::RandomPosterizeOperation(uint8_t min_bit, uint8_t max_bit)
|
|
|
|
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}
|
|
|
|
: min_bit_(min_bit), max_bit_(max_bit) {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool RandomPosterizeOperation::ValidateParams() {
|
|
|
|
bool RandomPosterizeOperation::ValidateParams() {
|
|
|
|
if (min_bit_ < 1 || min_bit_ > 8) {
|
|
|
|
if (bit_range_.size() != 2) {
|
|
|
|
MS_LOG(ERROR) << "RandomPosterize: min_bit value is out of range [1-8]: " << min_bit_;
|
|
|
|
MS_LOG(ERROR) << "RandomPosterize: bit_range needs to be of size 2 but is of size: " << bit_range_.size();
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (bit_range_[0] < 1 || bit_range_[0] > 8) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "RandomPosterize: min_bit value is out of range [1-8]: " << bit_range_[0];
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (max_bit_ < 1 || max_bit_ > 8) {
|
|
|
|
if (bit_range_[1] < 1 || bit_range_[1] > 8) {
|
|
|
|
MS_LOG(ERROR) << "RandomPosterize: max_bit value is out of range [1-8]: " << max_bit_;
|
|
|
|
MS_LOG(ERROR) << "RandomPosterize: max_bit value is out of range [1-8]: " << bit_range_[1];
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (max_bit_ < min_bit_) {
|
|
|
|
if (bit_range_[1] < bit_range_[0]) {
|
|
|
|
MS_LOG(ERROR) << "RandomPosterize: max_bit value is less than min_bit: max =" << max_bit_ << ", min = " << min_bit_;
|
|
|
|
MS_LOG(ERROR) << "RandomPosterize: max_bit value is less than min_bit: max =" << bit_range_[1]
|
|
|
|
|
|
|
|
<< ", min = " << bit_range_[0];
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
|
|
|
|
std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
|
|
|
|
std::shared_ptr<RandomPosterizeOp> tensor_op = std::make_shared<RandomPosterizeOp>(min_bit_, max_bit_);
|
|
|
|
std::shared_ptr<RandomPosterizeOp> tensor_op = std::make_shared<RandomPosterizeOp>(bit_range_);
|
|
|
|
return tensor_op;
|
|
|
|
return tensor_op;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|