diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc index 3f99820633..10a8179c76 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc @@ -187,7 +187,7 @@ class CNodeDecoder { if ((node.first)->isa()) { auto parameter = (node.first)->cast(); bool is_weight = AnfAlgo::IsParameterWeight(parameter); - kernel_info->SetFeatureMapFlag(!is_weight); + kernel_info->set_feature_map_flag(!is_weight); if (!is_weight) { feature_map_input_indexs.push_back(index - 1); } @@ -200,7 +200,7 @@ class CNodeDecoder { AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode_); } if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { - kernel_info->SetFeatureMapFlag(true); + kernel_info->set_feature_map_flag(true); } if (AnfAlgo::IsRealCNodeKernel(cnode_)) { AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode_); diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 1a7aeacc7c..1306d49931 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -31,6 +31,8 @@ namespace session { namespace { constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; +const std::set kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(), + prim::kPrimAssignSub->name()}; void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, std::unordered_set *visited_nodes) { MS_EXCEPTION_IF_NULL(node); @@ -417,21 +419,41 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) { } } +void KernelGraph::ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const { + if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) { + MS_LOG(EXCEPTION) << "Only supported to change the node [Assign , AssignSub, AssignAdd] node's input feature map " + "flag but got the node :" + << cnode->DebugString(); + } + auto input_node = AnfAlgo::GetInputNode(cnode, 0); + auto assign_value_node = AnfAlgo::GetInputNode(cnode, 1); + if (AnfAlgo::IsFeatureMapOutput(input_node)) { + return; + } + if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) { + auto kernel_info = static_cast(input_node->kernel_info()); + kernel_info->set_feature_map_flag(true); + } +} + void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { MS_EXCEPTION_IF_NULL(node); auto kernel_info = std::make_shared(); node->set_kernel_info(kernel_info); if (node->isa()) { + if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) { + ResetAssignInputFeaatureMapFlag(node->cast()); + } std::vector feature_map_input_indexs; - kernel_info->SetFeatureMapFlag(false); + kernel_info->set_feature_map_flag(false); for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) { if (AnfAlgo::IsFeatureMapInput(node, index)) { - kernel_info->SetFeatureMapFlag(true); + kernel_info->set_feature_map_flag(true); feature_map_input_indexs.push_back(index); } } if (AnfAlgo::GetInputTensorNum(node) == 0) { - kernel_info->SetFeatureMapFlag(true); + kernel_info->set_feature_map_flag(true); } if (AnfAlgo::IsRealKernel(node)) { // if the node only has the primitive(such as getNext) or the node's input has a feature map input @@ -446,7 +468,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { std::vector types; std::vector formats = {kOpFormat_DEFAULT}; if (node->isa()) { - kernel_info->SetFeatureMapFlag(false); + kernel_info->set_feature_map_flag(false); types.emplace_back(kTypeUnknown); auto value_node = node->cast(); SyncDeviceInfoToValueNode(value_node, &formats, &types); @@ -455,7 +477,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { auto parameter = node->cast(); MS_EXCEPTION_IF_NULL(parameter); bool is_weight = AnfAlgo ::IsParameterWeight(parameter); - kernel_info->SetFeatureMapFlag(!is_weight); + kernel_info->set_feature_map_flag(!is_weight); types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0)); } // set parameter initaial device data type diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 038d7337e5..50334e58c2 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -100,6 +100,7 @@ class KernelGraph : public FuncGraph { CNodePtr NewCNode(const std::vector &inputs) override; void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); CNodePtr NewCNode(const CNodePtr &cnode); + void ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const; ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract); ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value); diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 8986c84599..8385e45a35 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -838,7 +838,6 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker new_value_node->set_abstract(value_node->abstract()); // create new kernel_info of new value_node auto kernel_info = std::make_shared(); - kernel_info->SetFeatureMapFlag(false); new_value_node->set_kernel_info(kernel_info); // create kernel_build_info for new value node auto kernel_build_info_builder = std::make_shared(); diff --git a/mindspore/ccsrc/runtime/device/kernel_info.h b/mindspore/ccsrc/runtime/device/kernel_info.h index e9d997cb5e..7f8d17e0aa 100644 --- a/mindspore/ccsrc/runtime/device/kernel_info.h +++ b/mindspore/ccsrc/runtime/device/kernel_info.h @@ -48,7 +48,7 @@ class KernelInfo : public KernelInfoDevice { void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) { select_kernel_build_info_ = select_kernel_build_info; } - void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; } + void set_feature_map_flag(bool flag) { is_feature_map_ = flag; } const DeviceAddress *GetOutputAddr(size_t index) const; DeviceAddressPtr GetMutableOutputAddr(size_t index) const; bool OutputAddrExist(size_t index) const;