add validator for InvertPermutation

pull/1763/head
jiangjinsheng 5 years ago
parent 32a72c1979
commit 317ec43d88

@ -1039,6 +1039,8 @@ class InvertPermutation(PrimitiveWithInfer):
def __infer__(self, x):
x_shp = x['shape']
x_value = x['value']
if x_value is None:
raise ValueError(f'For \'{self.name}\' the input value must be const.')
validator.check_value_type("shape", x_shp, [tuple, list], self.name)
if mstype.issubclass_(x['dtype'], mstype.tensor):
validator.check('x dimension', len(x_shp), '', 1, Rel.EQ, self.name)
@ -1047,6 +1049,10 @@ class InvertPermutation(PrimitiveWithInfer):
z = [x_value[i] for i in range(len(x_value))]
z.sort()
validator.check(f'value length', len(x_value), f'unique value length', len(set(x_value)), Rel.EQ, self.name)
validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name)
validator.check(f'value max', max(x_value), '', len(x_value)-1, Rel.EQ, self.name)
y = [None] * len(x_value)
for i, value in enumerate(x_value):
validator.check_value_type("input[%d]" % i, value, [int], self.name)

Loading…
Cancel
Save