Add prim name to error message for other operators left

pull/367/head
fary86 5 years ago
parent 72f42fc37c
commit 6770c66ed9

File diff suppressed because it is too large

@@ -15,7 +15,8 @@
 """comm_ops"""
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
+from ..._checkparam import Rel
 from ...communication.management import get_rank, get_group_size, GlobalComm, get_group
 from ...common import dtype as mstype
 from ..primitive import PrimitiveWithInfer, prim_attr_register
@@ -148,12 +149,10 @@ class AllGather(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
-        if not isinstance(get_group(group), str):
-            raise TypeError("The group of AllGather should be str.")
+        validator.check_value_type('group', get_group(group), (str,), self.name)
         self.rank = get_rank(get_group(group))
         self.rank_size = get_group_size(get_group(group))
-        if self.rank >= self.rank_size:
-            raise ValueError("The rank of AllGather should be less than the rank_size.")
+        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
         self.add_prim_attr('rank_size', self.rank_size)
         self.add_prim_attr('group', get_group(group))
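
This is the commit's core pattern: each bare isinstance/raise pair becomes a Validator call that carries self.name, so the raised message identifies the failing primitive. The helpers below are hypothetical stand-ins written for illustration; only their call shapes (argument order, the trailing prim name) are taken from the diff, not from MindSpore's _checkparam module.

    # Hypothetical stand-ins for the Validator helpers and Rel used above;
    # only the call shapes come from the diff, not MindSpore's source.
    class Rel:
        LT, EQ, GE, GT = "<", "==", ">=", ">"

    _COMPARE = {"<": lambda a, b: a < b, "==": lambda a, b: a == b,
                ">=": lambda a, b: a >= b, ">": lambda a, b: a > b}

    def check_value_type(arg_name, arg_value, valid_types, prim_name):
        """Raise a TypeError naming the primitive when the type is wrong."""
        if not isinstance(arg_value, tuple(valid_types)):
            raise TypeError(f"For '{prim_name}', '{arg_name}' should be one of "
                            f"{valid_types}, but got {type(arg_value).__name__}.")
        return arg_value

    def check(arg_name, arg_value, value_name, value, rel, prim_name):
        """Raise a ValueError naming the primitive when the relation fails."""
        if not _COMPARE[rel](arg_value, value):
            raise ValueError(f"For '{prim_name}', '{arg_name}' ({arg_value}) must be "
                             f"{rel} '{value_name}' ({value}).")
        return arg_value

With these stand-ins, check('rank', 8, 'rank_size', 8, Rel.LT, 'AllGather') raises "For 'AllGather', 'rank' (8) must be < 'rank_size' (8)", where the old code reported only a fixed string.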
@@ -163,7 +162,7 @@ class AllGather(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype):
         if x_dtype == mstype.bool_:
-            raise TypeError("AllGather does not support 'Bool' as the dtype of input!")
+            raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
         return x_dtype

     def __call__(self, tensor):
@@ -205,10 +204,8 @@ class ReduceScatter(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
-        if not isinstance(op, type(ReduceOp.SUM)):
-            raise TypeError("The operation of ReduceScatter should be {}.".format(type(ReduceOp.SUM)))
-        if not isinstance(get_group(group), str):
-            raise TypeError("The group of ReduceScatter should be str.")
+        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
+        validator.check_value_type('group', get_group(group), (str,), self.name)
         self.op = op
         self.rank_size = get_group_size(get_group(group))
         self.add_prim_attr('rank_size', self.rank_size)
@@ -216,13 +213,13 @@ class ReduceScatter(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         if x_shape[0] % self.rank_size != 0:
-            raise ValueError("The first dimension of x should be divided by rank_size.")
+            raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.")
         x_shape[0] = int(x_shape[0]/self.rank_size)
         return x_shape

     def infer_dtype(self, x_dtype):
         if x_dtype == mstype.bool_:
-            raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!")
+            raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
         return x_dtype

     def __call__(self, tensor):
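
The f-string change is the same idea applied to hand-written raises: self.name replaces the hard-coded class name, so the message always matches the op that actually failed. A minimal runnable illustration, with a plain sentinel standing in for mstype.bool_ and a toy class standing in for PrimitiveWithInfer:

    # Illustration only: bool_ is a stand-in sentinel for mstype.bool_,
    # _Prim a toy stand-in for a PrimitiveWithInfer subclass with a .name.
    bool_ = object()

    class _Prim:
        def __init__(self, name):
            self.name = name

        def infer_dtype(self, x_dtype):
            if x_dtype is bool_:
                raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
            return x_dtype

    try:
        _Prim("ReduceScatter").infer_dtype(bool_)
    except TypeError as e:
        print(e)  # ReduceScatter does not support 'Bool' as the dtype of input!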
@@ -270,10 +267,8 @@ class Broadcast(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
-        if not isinstance(root_rank, int):
-            raise TypeError("The root_rank of Broadcast should be int.")
-        if not isinstance(get_group(group), str):
-            raise TypeError("The group of Broadcast should be str.")
+        validator.check_value_type('root_rank', root_rank, (int,), self.name)
+        validator.check_value_type('group', get_group(group), (str,), self.name)
         self.add_prim_attr('group', get_group(group))

     def infer_shape(self, x_shape):
@@ -281,7 +276,7 @@ class Broadcast(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype):
         if x_dtype == mstype.bool_:
-            raise TypeError("Broadcast does not support 'Bool' as the dtype of input!")
+            raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
         return x_dtype
@@ -311,8 +306,7 @@ class _AlltoAll(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
         """init AlltoAll"""
-        if not isinstance(get_group(group), str):
-            raise TypeError("The group of AllGather should be str.")
+        validator.check_value_type('group', get_group(group), (str,), self.name)
         self.split_count = split_count
         self.split_dim = split_dim
         self.concat_dim = concat_dim
@@ -325,7 +319,7 @@ class _AlltoAll(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype):
         if x_dtype == mstype.bool_:
-            raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!")
+            raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
         return x_dtype

     def __call__(self, tensor):
@@ -420,6 +414,6 @@ class _GetTensorSlice(PrimitiveWithInfer):
     def infer_value(self, x, dev_mat, tensor_map):
         from mindspore.parallel._tensor import _load_tensor
-        validator.check_type("dev_mat", dev_mat, [tuple])
-        validator.check_type("tensor_map", tensor_map, [tuple])
+        validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
+        validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
         return _load_tensor(x, dev_mat, tensor_map)

@@ -16,7 +16,8 @@
 """control_ops"""
 from ...common import dtype as mstype
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
+from ..._checkparam import Rel
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
@@ -123,11 +124,11 @@ class GeSwitch(PrimitiveWithInfer):
         raise NotImplementedError

     def infer_shape(self, data, pred):
-        validator.check_scalar_shape_input("pred", pred)
+        validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name)
         return (data, data)

     def infer_dtype(self, data_type, pred_type):
-        validator.check_type("pred", pred_type, [type(mstype.bool_)])
+        validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
         return (data_type, data_type)
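
The GeSwitch change expresses "pred must be a scalar" as a rank check, since a 0-d tensor has an empty shape tuple. A check_integer-style stand-in (hypothetical, with plain comparison strings in place of Rel members) shows the shape of such a helper:

    # Hypothetical check_integer stand-in; the real helper lives in
    # mindspore._checkparam and only its call shape comes from the diff.
    def check_integer(arg_name, arg_value, value, rel, prim_name):
        """Validate an integer against `value` under relation `rel`."""
        rels = {"==": lambda a, b: a == b, ">": lambda a, b: a > b,
                ">=": lambda a, b: a >= b}
        if not isinstance(arg_value, int) or not rels[rel](arg_value, value):
            raise ValueError(f"For '{prim_name}', '{arg_name}' should be {rel} {value}, "
                             f"but got {arg_value}.")
        return arg_value

    # A 0-d (scalar) tensor has shape (), so the rank check passes:
    check_integer("pred rank", len(()), 0, "==", "GeSwitch")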

@@ -14,7 +14,7 @@
 # ============================================================================
 """debug_ops"""
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer
@@ -219,5 +219,5 @@ class Print(PrimitiveWithInfer):
     def infer_dtype(self, *inputs):
         for dtype in inputs:
-            validator.check_subclass("input", dtype, (mstype.tensor, mstype.string))
+            validator.check_subclass("input", dtype, (mstype.tensor, mstype.string), self.name)
         return mstype.int32
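
Print's check gains only the trailing self.name. A hedged sketch of a check_subclass-style helper, using ordinary Python types in place of MindSpore metatypes such as mstype.tensor and mstype.string:

    # Hypothetical check_subclass stand-in with plain Python types instead
    # of MindSpore metatypes; only the call shape comes from the diff.
    def check_subclass(arg_name, arg_type, templates, prim_name):
        """Raise TypeError naming the primitive unless arg_type subclasses one of templates."""
        if not any(issubclass(arg_type, t) for t in templates):
            raise TypeError(f"For '{prim_name}', the type of '{arg_name}' should be a "
                            f"subclass of {templates}, but got {arg_type}.")
        return arg_type

    check_subclass("input", bool, (int, str), "Print")  # passes: bool subclasses int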

@@ -16,7 +16,7 @@
 """Other operators."""
 from ..._c_expression import signature_rw as sig_rw
 from ..._c_expression import signature_kind as sig_kind
-from ..._checkparam import ParamValidator as validator, Rel
+from ..._checkparam import Validator as validator, Rel
 from ...common import dtype as mstype
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
@@ -82,22 +82,21 @@ class BoundingBoxEncode(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
-        validator.check_type('means', means, [tuple])
-        validator.check_type('stds', stds, [tuple])
-        validator.check("means len", len(means), '', 4)
-        validator.check("stds len", len(stds), '', 4)
+        validator.check_value_type('means', means, [tuple], self.name)
+        validator.check_value_type('stds', stds, [tuple], self.name)
+        validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
+        validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)

     def infer_shape(self, anchor_box, groundtruth_box):
-        validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0])
-        validator.check('anchor_box shape[1]', anchor_box[1], '', 4)
-        validator.check('groundtruth_box shape[1]', groundtruth_box[1], '', 4)
+        validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ,
+                        self.name)
+        validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
+        validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name)
         return anchor_box

     def infer_dtype(self, anchor_box, groundtruth_box):
-        args = {"anchor_box": anchor_box,
-                "groundtruth_box": groundtruth_box
-                }
-        validator.check_type_same(args, mstype.number_type)
+        args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
         return anchor_box
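
Passing the dtypes as a named dict lets the helper report which tensor argument is offending, again tagged with the prim name. A hypothetical check_tensor_type_same stand-in, with strings in place of MindSpore dtypes:

    # Hypothetical check_tensor_type_same stand-in; real MindSpore dtypes
    # carry a tensor element type, plain strings stand in for them here.
    def check_tensor_type_same(args, valid_dtypes, prim_name):
        """All named dtypes must agree and be drawn from valid_dtypes."""
        dtypes = set(args.values())
        if len(dtypes) != 1 or dtypes.pop() not in valid_dtypes:
            raise TypeError(f"For '{prim_name}', the dtypes of {sorted(args)} should be "
                            f"equal and in {valid_dtypes}, but got {args}.")

    check_tensor_type_same({"anchor_box": "float16", "groundtruth_box": "float16"},
                           ("float16", "float32"), "BoundingBoxEncode")  # passes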
@@ -126,26 +125,24 @@ class BoundingBoxDecode(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016):
-        validator.check_type('means', means, [tuple])
-        validator.check_type('stds', stds, [tuple])
-        validator.check_type('wh_ratio_clip', wh_ratio_clip, [float])
-        validator.check("means", len(means), '', 4)
-        validator.check("stds", len(stds), '', 4)
+        validator.check_value_type('means', means, [tuple], self.name)
+        validator.check_value_type('stds', stds, [tuple], self.name)
+        validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
+        validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
+        validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
         if max_shape is not None:
-            validator.check_type('max_shape', max_shape, [tuple])
-            validator.check("max_shape", len(max_shape), '', 2)
+            validator.check_value_type('max_shape', max_shape, [tuple], self.name)
+            validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name)

     def infer_shape(self, anchor_box, deltas):
-        validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0])
-        validator.check('anchor_box shape[1]', anchor_box[1], '', 4)
-        validator.check('deltas shape[1]', deltas[1], '', 4)
+        validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name)
+        validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
+        validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name)
         return anchor_box

     def infer_dtype(self, anchor_box, deltas):
-        args = {"anchor_box": anchor_box,
-                "deltas": deltas
-                }
-        validator.check_type_same(args, mstype.number_type)
+        args = {"anchor_box": anchor_box, "deltas": deltas}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
         return anchor_box
@@ -168,10 +165,10 @@ class CheckValid(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output'])

     def infer_shape(self, bboxes_shape, metas_shape):
-        validator.check_shape_length("bboxes shape length", len(bboxes_shape), 2, Rel.EQ)
-        validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ)
-        validator.check_shape_length("img_metas shape length", len(metas_shape), 1, Rel.EQ)
-        validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ)
+        validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("bboxes_shape[-1]", bboxes_shape[-1], 4, Rel.EQ, self.name)
+        validator.check_integer("img_metas rank", len(metas_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("img_metas shape[0]", metas_shape[0], 3, Rel.EQ, self.name)
         return bboxes_shape[:-1]

     def infer_dtype(self, bboxes_type, metas_type):
@@ -221,18 +218,16 @@ class IOU(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])

     def infer_shape(self, anchor_boxes, gt_boxes):
-        validator.check('gt_boxes shape[1]', gt_boxes[1], '', 4)
-        validator.check('anchor_boxes shape[1]', anchor_boxes[1], '', 4)
-        validator.check('anchor_boxes rank', len(anchor_boxes), '', 2)
-        validator.check('gt_boxes rank', len(gt_boxes), '', 2)
+        validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name)
+        validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name)
+        validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name)
+        validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name)
         iou = [gt_boxes[0], anchor_boxes[0]]
         return iou

     def infer_dtype(self, anchor_boxes, gt_boxes):
-        validator.check_subclass("anchor_boxes", anchor_boxes, mstype.tensor)
-        validator.check_subclass("gt_boxes", gt_boxes, mstype.tensor)
         args = {"anchor_boxes": anchor_boxes, "gt_boxes": gt_boxes}
-        validator.check_type_same(args, (mstype.float16,))
+        validator.check_tensor_type_same(args, (mstype.float16,), self.name)
         return anchor_boxes
@@ -270,7 +265,7 @@ class MakeRefKey(Primitive):
     @prim_attr_register
     def __init__(self, tag):
-        validator.check_type('tag', tag, (str,))
+        validator.check_value_type('tag', tag, (str,), self.name)

     def __call__(self):
         pass

@@ -15,7 +15,7 @@
 """Operators for random."""
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ..primitive import PrimitiveWithInfer, prim_attr_register
@@ -52,16 +52,15 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, count=256, seed=0, seed2=0):
         """Init RandomChoiceWithMask"""
-        validator.check_type("count", count, [int])
-        validator.check_integer("count", count, 0, Rel.GT)
-        validator.check_type('seed', seed, [int])
-        validator.check_type('seed2', seed2, [int])
+        validator.check_value_type("count", count, [int], self.name)
+        validator.check_integer("count", count, 0, Rel.GT, self.name)
+        validator.check_value_type('seed', seed, [int], self.name)
+        validator.check_value_type('seed2', seed2, [int], self.name)

     def infer_shape(self, x_shape):
-        validator.check_shape_length("input_x shape", len(x_shape), 1, Rel.GE)
+        validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
         return ([self.count, len(x_shape)], [self.count])

     def infer_dtype(self, x_dtype):
-        validator.check_subclass('x_dtype', x_dtype, mstype.tensor)
-        validator.check_typename('x_dtype', x_dtype, [mstype.bool_])
+        validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
         return (mstype.int32, mstype.bool_)
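
Taken together, every validation failure in these five operator files now names the failing primitive. A self-contained before/after illustration (the stand-in helper repeats the earlier sketch; it is not MindSpore's implementation):

    # Before: raise TypeError("The root_rank of Broadcast should be int.")
    # After, the message is derived from the primitive's own name:
    def check_value_type(arg_name, arg_value, valid_types, prim_name):
        if not isinstance(arg_value, tuple(valid_types)):
            raise TypeError(f"For '{prim_name}', '{arg_name}' should be one of "
                            f"{valid_types}, but got {type(arg_value).__name__}.")
        return arg_value

    try:
        check_value_type('count', 'not-an-int', (int,), 'RandomChoiceWithMask')
    except TypeError as e:
        print(e)  # For 'RandomChoiceWithMask', 'count' should be one of (<class 'int'>,), but got str.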
