|
|
|
@ -32,7 +32,7 @@ CutMixBatchOp::CutMixBatchOp(ImageBatchFormat image_batch_format, float alpha, f
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CutMixBatchOp::GetCropBox(int height, int width, float lam, int *x, int *y, int *crop_width, int *crop_height) {
|
|
|
|
|
float cut_ratio = 1 - lam;
|
|
|
|
|
const float cut_ratio = 1 - lam;
|
|
|
|
|
int cut_w = static_cast<int>(width * cut_ratio);
|
|
|
|
|
int cut_h = static_cast<int>(height * cut_ratio);
|
|
|
|
|
std::uniform_int_distribution<int> width_uniform_distribution(0, width);
|
|
|
|
@ -116,7 +116,6 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|
|
|
|
RETURN_IF_NOT_OK(Crop(rand_image, &cropped, x, y, crop_width, crop_height));
|
|
|
|
|
RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::HWC));
|
|
|
|
|
label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[1] * image_shape[2]));
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
// NCHW Format
|
|
|
|
|
GetCropBox(static_cast<int32_t>(image_shape[2]), static_cast<int32_t>(image_shape[3]), lam, &x, &y, &crop_width,
|
|
|
|
|