diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 044ab3a7f1..177e9c510b 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -431,7 +431,7 @@ class Reshape(PrimitiveWithInfer): return out -class Shape(Primitive): +class Shape(PrimitiveWithInfer): """ Returns the shape of input tensor. @@ -452,6 +452,13 @@ class Shape(Primitive): def __init__(self): """Initialize Shape""" + def __infer__(self, x): + validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) + out = {'shape': (), + 'dtype': mstype.tuple_, + 'value': tuple(x['shape'])} + return out + class DynamicShape(Primitive): """