diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 6ee6156118..5086ea7d21 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -23,11 +23,17 @@ namespace mindspore { namespace abstract { std::vector GetDependsFormMap(const CNodePtr &cnode) { - constexpr auto kUnsortedSegmentSum = "UnsortedSegmentSum"; - constexpr auto kUnsortedSegmentMin = "UnsortedSegmentMin"; - constexpr auto kUnsortedSegmentMax = "UnsortedSegmentMax"; + const auto kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name(); + const auto kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name(); + const auto kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name(); + const auto kGather = prim::kPrimGather->name(); + const auto kGatherV2 = prim::kPrimGatherV2->name(); + const auto kDynamicShape = prim::kPrimDynamicShape->name(); + const auto kRange = prim::kPrimRange->name(); static std::map> dynamic_shape_depends = { - {kUnsortedSegmentSum, {2}}, {kUnsortedSegmentMin, {2}}, {kUnsortedSegmentMax, {2}}}; + {kUnsortedSegmentSum, {2}}, {kUnsortedSegmentMin, {2}}, {kUnsortedSegmentMax, {2}}, {kGather, {2}}, + {kGatherV2, {2}}, {kDynamicShape, {0}}, {kRange, {0, 1, 2}}, + }; MS_EXCEPTION_IF_NULL(cnode); if (cnode->inputs().empty()) { MS_LOG(EXCEPTION) << "Invalid inputs"; diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b3943b9464..edf72ca33e 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -593,7 +593,6 @@ class DynamicShape(Primitive): """init Shape""" self.init_prim_io_names(inputs=['tensor'], outputs=['output']) self.add_prim_attr('is_dynamic_shape', True) - self.add_prim_attr("dynamic_shape_depends", [0]) class Squeeze(PrimitiveWithInfer): @@ -811,7 +810,6 @@ class Gather(PrimitiveWithCheck): def __init__(self): """Initialize index_select""" self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) - self.add_prim_attr("dynamic_shape_depends", [2]) def __check__(self, params, indices, axis): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) @@ -836,7 +834,6 @@ class GatherV2(PrimitiveWithCheck): def __init__(self): """Initialize index_select""" self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) - self.add_prim_attr("dynamic_shape_depends", [2]) def __check__(self, params, indices, axis): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) @@ -1987,7 +1984,6 @@ class UnsortedSegmentMin(PrimitiveWithCheck): def __init__(self): """Initialize UnsortedSegmentMin""" self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) - self.add_prim_attr("dynamic_shape_depends", [2]) def __check__(self, x, segment_ids, num_segments): x_shape = x['shape'] @@ -2043,7 +2039,6 @@ class UnsortedSegmentMax(PrimitiveWithCheck): def __init__(self): """Initialize UnsortedSegmentMax""" self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) - self.add_prim_attr("dynamic_shape_depends", [2]) def __check__(self, x, segment_ids, num_segments): x_shape = x['shape'] @@ -4980,10 +4975,6 @@ class Range(PrimitiveWithCheck): self.maxlen = maxlen self.add_prim_attr('maxlen', maxlen) - self.add_prim_attr("dynamic_shape_depends", [0]) - self.add_prim_attr("dynamic_shape_depends", [1]) - self.add_prim_attr("dynamic_shape_depends", [2]) - def check_shape(self, start_shape, limit_shape, delta_shape): validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name) validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)