|
|
|
@ -431,7 +431,7 @@ class Reshape(PrimitiveWithInfer):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Shape(Primitive):
|
|
|
|
|
class Shape(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
Returns the shape of input tensor.
|
|
|
|
|
|
|
|
|
@ -452,6 +452,13 @@ class Shape(Primitive):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
"""Initialize Shape"""
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x):
|
|
|
|
|
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
|
|
|
|
|
out = {'shape': (),
|
|
|
|
|
'dtype': mstype.tuple_,
|
|
|
|
|
'value': tuple(x['shape'])}
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DynamicShape(Primitive):
|
|
|
|
|
"""
|
|
|
|
|