|
|
|
@ -589,7 +589,7 @@ class Squeeze(PrimitiveWithInfer):
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Transpose(PrimitiveWithInfer):
|
|
|
|
|
class Transpose(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
Permutes the dimensions of input tensor according to input permutation.
|
|
|
|
|
|
|
|
|
@ -621,32 +621,13 @@ class Transpose(PrimitiveWithInfer):
|
|
|
|
|
"""Initialize Transpose"""
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, perm):
|
|
|
|
|
x_shape = x['shape']
|
|
|
|
|
p_value = perm['value']
|
|
|
|
|
x_type = x['dtype']
|
|
|
|
|
validator.check_value_type("p_value", p_value, [tuple], self.name)
|
|
|
|
|
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
|
|
|
|
|
|
|
|
|
|
if len(x_shape) != len(p_value):
|
|
|
|
|
def check_shape(self, x, perm):
|
|
|
|
|
validator.check_value_type("perm", perm, [tuple], self.name)
|
|
|
|
|
if len(x) != len(perm):
|
|
|
|
|
raise ValueError('The dimension of x and perm must be equal.')
|
|
|
|
|
|
|
|
|
|
tmp = list(p_value)
|
|
|
|
|
for i, dim in enumerate(p_value):
|
|
|
|
|
validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
|
|
|
|
|
validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
|
|
|
|
|
tmp.remove(dim)
|
|
|
|
|
if dim in tmp:
|
|
|
|
|
raise ValueError('The value of perm is wrong.')
|
|
|
|
|
|
|
|
|
|
out_shapes = []
|
|
|
|
|
for i in p_value:
|
|
|
|
|
out_shapes.append(x_shape[i])
|
|
|
|
|
out = {'shape': tuple(out_shapes),
|
|
|
|
|
'dtype': x['dtype'],
|
|
|
|
|
'value': None}
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def check_dtype(self, x, perm):
|
|
|
|
|
validator.check_subclass("x", x, mstype.tensor, self.name)
|
|
|
|
|
|
|
|
|
|
class Unique(Primitive):
|
|
|
|
|
"""
|
|
|
|
|