|
|
|
@ -386,15 +386,15 @@ def _raise_dims_rank_error(input_shape, param_name, func_name):
|
|
|
|
|
raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}")
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _get_bbox(rank, shape, central_fraction):
|
|
|
|
|
def _get_bbox(rank, shape, size_h, size_w):
|
|
|
|
|
"""get bbox start and size for slice"""
|
|
|
|
|
if rank == 3:
|
|
|
|
|
c, h, w = shape
|
|
|
|
|
else:
|
|
|
|
|
n, c, h, w = shape
|
|
|
|
|
|
|
|
|
|
bbox_h_start = int(np.round((float(h) - float(h) * central_fraction) / 2))
|
|
|
|
|
bbox_w_start = int(np.round((float(w) - float(w) * central_fraction) / 2))
|
|
|
|
|
bbox_h_start = int((float(h) - size_h) / 2)
|
|
|
|
|
bbox_w_start = int((float(w) - size_w) / 2)
|
|
|
|
|
bbox_h_size = h - bbox_h_start * 2
|
|
|
|
|
bbox_w_size = w - bbox_w_start * 2
|
|
|
|
|
|
|
|
|
@ -436,12 +436,15 @@ class CentralCrop(Cell):
|
|
|
|
|
def construct(self, image):
|
|
|
|
|
image_shape = F.shape(image)
|
|
|
|
|
rank = len(image_shape)
|
|
|
|
|
h, w = image_shape[-2], image_shape[-1]
|
|
|
|
|
if not rank in (3, 4):
|
|
|
|
|
return _raise_dims_rank_error(image_shape, "image", self.cls_name)
|
|
|
|
|
if self.central_fraction == 1.0:
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction)
|
|
|
|
|
size_h = self.central_fraction * h
|
|
|
|
|
size_w = self.central_fraction * w
|
|
|
|
|
bbox_begin, bbox_size = _get_bbox(rank, image_shape, size_h, size_w)
|
|
|
|
|
image = self.slice(image, bbox_begin, bbox_size)
|
|
|
|
|
|
|
|
|
|
return image
|
|
|
|
|