|
|
@ -2711,7 +2711,7 @@ class ROIAlign(PrimitiveWithInfer):
|
|
|
|
Examples:
|
|
|
|
Examples:
|
|
|
|
>>> input_tensor = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
|
|
|
|
>>> input_tensor = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
|
|
|
|
>>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mindspore.float32)
|
|
|
|
>>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mindspore.float32)
|
|
|
|
>>> roi_align = P.ROIAlign(1, 1, 0.5, 2)
|
|
|
|
>>> roi_align = P.ROIAlign(2, 2, 0.5, 2)
|
|
|
|
>>> output_tensor = roi_align(input_tensor, rois)
|
|
|
|
>>> output_tensor = roi_align(input_tensor, rois)
|
|
|
|
>>> assert output_tensor == Tensor(np.array([[[[2.15]]]]), mindspore.float32)
|
|
|
|
>>> assert output_tensor == Tensor(np.array([[[[2.15]]]]), mindspore.float32)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -4980,4 +4980,5 @@ class LRN(PrimitiveWithInfer):
|
|
|
|
return x_dtype
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
|
|
|
|
validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name)
|
|
|
|
return x_shape
|
|
|
|
return x_shape
|
|
|
|