|
|
|
@ -16,7 +16,7 @@
|
|
|
|
|
"""Other operators."""
|
|
|
|
|
from ..._c_expression import signature_rw as sig_rw
|
|
|
|
|
from ..._c_expression import signature_kind as sig_kind
|
|
|
|
|
from ..._checkparam import ParamValidator as validator, Rel
|
|
|
|
|
from ..._checkparam import Validator as validator, Rel
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
|
|
|
|
|
|
|
|
@ -82,22 +82,21 @@ class BoundingBoxEncode(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
|
|
|
|
|
validator.check_type('means', means, [tuple])
|
|
|
|
|
validator.check_type('stds', stds, [tuple])
|
|
|
|
|
validator.check("means len", len(means), '', 4)
|
|
|
|
|
validator.check("stds len", len(stds), '', 4)
|
|
|
|
|
validator.check_value_type('means', means, [tuple], self.name)
|
|
|
|
|
validator.check_value_type('stds', stds, [tuple], self.name)
|
|
|
|
|
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, anchor_box, groundtruth_box):
|
|
|
|
|
validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0])
|
|
|
|
|
validator.check('anchor_box shape[1]', anchor_box[1], '', 4)
|
|
|
|
|
validator.check('groundtruth_box shape[1]', groundtruth_box[1], '', 4)
|
|
|
|
|
validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ,
|
|
|
|
|
self.name)
|
|
|
|
|
validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name)
|
|
|
|
|
return anchor_box
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, anchor_box, groundtruth_box):
|
|
|
|
|
args = {"anchor_box": anchor_box,
|
|
|
|
|
"groundtruth_box": groundtruth_box
|
|
|
|
|
}
|
|
|
|
|
validator.check_type_same(args, mstype.number_type)
|
|
|
|
|
args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
|
|
|
|
return anchor_box
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -126,26 +125,24 @@ class BoundingBoxDecode(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016):
|
|
|
|
|
validator.check_type('means', means, [tuple])
|
|
|
|
|
validator.check_type('stds', stds, [tuple])
|
|
|
|
|
validator.check_type('wh_ratio_clip', wh_ratio_clip, [float])
|
|
|
|
|
validator.check("means", len(means), '', 4)
|
|
|
|
|
validator.check("stds", len(stds), '', 4)
|
|
|
|
|
validator.check_value_type('means', means, [tuple], self.name)
|
|
|
|
|
validator.check_value_type('stds', stds, [tuple], self.name)
|
|
|
|
|
validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
|
|
|
|
|
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
|
|
|
|
|
if max_shape is not None:
|
|
|
|
|
validator.check_type('max_shape', max_shape, [tuple])
|
|
|
|
|
validator.check("max_shape", len(max_shape), '', 2)
|
|
|
|
|
validator.check_value_type('max_shape', max_shape, [tuple], self.name)
|
|
|
|
|
validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, anchor_box, deltas):
|
|
|
|
|
validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0])
|
|
|
|
|
validator.check('anchor_box shape[1]', anchor_box[1], '', 4)
|
|
|
|
|
validator.check('deltas shape[1]', deltas[1], '', 4)
|
|
|
|
|
validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name)
|
|
|
|
|
return anchor_box
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, anchor_box, deltas):
|
|
|
|
|
args = {"anchor_box": anchor_box,
|
|
|
|
|
"deltas": deltas
|
|
|
|
|
}
|
|
|
|
|
validator.check_type_same(args, mstype.number_type)
|
|
|
|
|
args = {"anchor_box": anchor_box, "deltas": deltas}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
|
|
|
|
return anchor_box
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -168,10 +165,10 @@ class CheckValid(PrimitiveWithInfer):
|
|
|
|
|
self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output'])
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, bboxes_shape, metas_shape):
|
|
|
|
|
validator.check_shape_length("bboxes shape length", len(bboxes_shape), 2, Rel.EQ)
|
|
|
|
|
validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ)
|
|
|
|
|
validator.check_shape_length("img_metas shape length", len(metas_shape), 1, Rel.EQ)
|
|
|
|
|
validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ)
|
|
|
|
|
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer("bboxes_shape[-1]", bboxes_shape[-1], 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer("img_metas rank", len(metas_shape), 1, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer("img_metas shape[0]", metas_shape[0], 3, Rel.EQ, self.name)
|
|
|
|
|
return bboxes_shape[:-1]
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, bboxes_type, metas_type):
|
|
|
|
@ -221,18 +218,16 @@ class IOU(PrimitiveWithInfer):
|
|
|
|
|
self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, anchor_boxes, gt_boxes):
|
|
|
|
|
validator.check('gt_boxes shape[1]', gt_boxes[1], '', 4)
|
|
|
|
|
validator.check('anchor_boxes shape[1]', anchor_boxes[1], '', 4)
|
|
|
|
|
validator.check('anchor_boxes rank', len(anchor_boxes), '', 2)
|
|
|
|
|
validator.check('gt_boxes rank', len(gt_boxes), '', 2)
|
|
|
|
|
validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name)
|
|
|
|
|
iou = [gt_boxes[0], anchor_boxes[0]]
|
|
|
|
|
return iou
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, anchor_boxes, gt_boxes):
|
|
|
|
|
validator.check_subclass("anchor_boxes", anchor_boxes, mstype.tensor)
|
|
|
|
|
validator.check_subclass("gt_boxes", gt_boxes, mstype.tensor)
|
|
|
|
|
args = {"anchor_boxes": anchor_boxes, "gt_boxes": gt_boxes}
|
|
|
|
|
validator.check_type_same(args, (mstype.float16,))
|
|
|
|
|
validator.check_tensor_type_same(args, (mstype.float16,), self.name)
|
|
|
|
|
return anchor_boxes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -270,7 +265,7 @@ class MakeRefKey(Primitive):
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, tag):
|
|
|
|
|
validator.check_type('tag', tag, (str,))
|
|
|
|
|
validator.check_value_type('tag', tag, (str,), self.name)
|
|
|
|
|
|
|
|
|
|
def __call__(self):
|
|
|
|
|
pass
|
|
|
|
|