|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
|
|
|
|
|
"""image_ops"""
|
|
|
|
|
from ... import context
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
@ -84,6 +85,7 @@ class CropAndResize(PrimitiveWithInfer):
|
|
|
|
|
self.method = method
|
|
|
|
|
validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
|
|
|
|
|
self.extrapolation_value = extrapolation_value
|
|
|
|
|
self.is_ge = context.get_context("enable_ge")
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, boxes, box_index, crop_size):
|
|
|
|
|
# get shape
|
|
|
|
@ -124,6 +126,9 @@ class CropAndResize(PrimitiveWithInfer):
|
|
|
|
|
crop_height = crop_size_value[0]
|
|
|
|
|
crop_width = crop_size_value[1]
|
|
|
|
|
depth = x_shape[3]
|
|
|
|
|
return {'shape': (num_boxes, crop_height, crop_width, depth),
|
|
|
|
|
out_shape = (num_boxes, crop_height, crop_width, depth)
|
|
|
|
|
if self.is_ge:
|
|
|
|
|
out_shape = (num_boxes, x_shape[1], crop_height, crop_width)
|
|
|
|
|
return {'shape': out_shape,
|
|
|
|
|
'dtype': mstype.float32,
|
|
|
|
|
'value': None}
|
|
|
|
|