diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc index 39ab8e4721..bf271d6bcc 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc +++ b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc @@ -100,13 +100,13 @@ void FeedTeOpTensorOutputArg(const NotNull &cnode, void FeedTeOpConstTensor(const NotNull &cnode, const std::map &depend_tensor_map, NotNull *> const_inputs) { MS_LOG(INFO) << "FeedTeOpConstTensor start, node:" << cnode->fullname_with_scope(); - if (!AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode.get())) { + auto depends_list_me = abstract::GetDependsFormMap(cnode); + if (depends_list_me.empty()) { MS_LOG(INFO) << "No input depend found, " << cnode->fullname_with_scope(); return; } std::vector depends_list; - std::vector depends_list_me = AnfAlgo::GetNodeAttr>(cnode.get(), kDynamicShapeDepends); (void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depends_list), [](const int64_t &value) { return static_cast(value); }); for (auto index : depends_list) { diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.h b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.h index 066d21dd05..17c4262f19 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.h +++ b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.h @@ -25,6 +25,7 @@ #include "ir/anf.h" #include "ir/tensor.h" #include "register/op_tiling.h" +#include "abstract/primitive_infer_map.h" namespace mindspore { namespace device { diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc index 1ed054895f..79c1de652a 100644 --- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc +++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc @@ -39,16 +39,14 @@ void DynamicKernel::Initialize() { is_input_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrInputIsDynamicShape); is_output_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrOutputIsDynamicShape); - auto have_depends = AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode_ptr_); - if (!have_depends) { + auto ret = abstract::GetDependsFormMap(cnode_ptr_); + if (ret.empty()) { MS_LOG(DEBUG) << "No dynamic_shape_depends found"; return; } MS_LOG(INFO) << "Have depends"; - std::vector depends_list_me = AnfAlgo::GetNodeAttr>(cnode_ptr_, kDynamicShapeDepends); - (void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depend_list_), + (void)std::transform(ret.begin(), ret.end(), std::back_inserter(depend_list_), [](const int64_t &value) { return static_cast(value); }); - MS_LOG(INFO) << "Init End"; } diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h index 637bdc410e..fcf3b0ba8e 100644 --- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h +++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h @@ -23,6 +23,7 @@ #include #include "ir/anf.h" #include "ir/tensor.h" +#include "abstract/primitive_infer_map.h" namespace mindspore { namespace device { diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index dd94ebfc57..b356662590 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -22,6 +22,25 @@ namespace mindspore { namespace abstract { +std::vector GetDependsFormMap(const CNodePtr &cnode) { + constexpr auto kUnsortedSegmentSum = "UnsortedSegmentSum"; + constexpr auto kUnsortedSegmentMin = "UnsortedSegmentMin"; + constexpr auto kUnsortedSegmentMax = "UnsortedSegmentMax"; + static std::map> dynamic_shape_depends = { + {kUnsortedSegmentSum, {2}}, {kUnsortedSegmentMin, {2}}, {kUnsortedSegmentMax, {2}}}; + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().empty()) { + MS_LOG(EXCEPTION) << "Invalid inputs"; + } + auto primitive = GetValueNode(cnode->inputs()[0]); + MS_EXCEPTION_IF_NULL(primitive); + auto iter = dynamic_shape_depends.find(primitive->ToString()); + if (iter != dynamic_shape_depends.end()) { + return iter->second; + } + return {}; +} + PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { static PrimitiveEvalImplMap prim_eval_implement_map = { // Statements diff --git a/mindspore/core/abstract/primitive_infer_map.h b/mindspore/core/abstract/primitive_infer_map.h index 1274e1ac50..77329ce901 100644 --- a/mindspore/core/abstract/primitive_infer_map.h +++ b/mindspore/core/abstract/primitive_infer_map.h @@ -21,6 +21,8 @@ #include "ir/primitive.h" #include "base/core_ops.h" #include "abstract/abstract_value.h" +#include "ir/anf.h" + namespace mindspore { namespace abstract { using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, @@ -35,6 +37,8 @@ using PrimitiveEvalImplMap = PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); +std::vector GetDependsFormMap(const CNodePtr &cnode); + void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg); class RegisterStandardPrimitiveEvalHelper { diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index a792910808..d313ba0e19 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1892,7 +1892,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer): def __init__(self): """Initialize UnsortedSegmentSum""" self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) - self.add_prim_attr("dynamic_shape_depends", [2]) def __infer__(self, x, segment_ids, num_segments): x_type = x['dtype']