!1128 only call HasNodeAttr for CNodePtr type

Merge pull request !1128 from Margaret_wangrui/master
pull/1128/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit a61f8486c6

@ -45,7 +45,8 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
if (anf_node_ptr->inputs().size() != 2) {
MS_LOG(EXCEPTION) << "atomic Addr clean Node Input nodes not equal 2.";
}
auto pre_node = anf_node_ptr->inputs()[1];
MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]);
auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>();
// set clean output addr
if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) {
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAutomicOutputIndexs);

@ -546,7 +546,8 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
if (cnode->inputs().size() != 2) {
MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
}
auto pre_node = cnode->inputs()[1];
MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
// set clean output address
if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) {
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAutomicOutputIndexs);

@ -34,7 +34,8 @@ LabelGotoKernel::~LabelGotoKernel() {}
bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_LOG(INFO) << "LabelGotoKernel init";
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, anf_node)) {
auto cnode = anf_node->cast<CNodePtr>();
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
MS_LOG(EXCEPTION) << "LabelGotoKernel has no attr label_index";
}
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);

@ -34,7 +34,8 @@ LabelSetKernel::~LabelSetKernel() {}
bool LabelSetKernel::Init(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_LOG(INFO) << "LabelSetKernel init";
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, anf_node)) {
auto cnode = anf_node->cast<CNodePtr>();
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
}
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);

@ -38,7 +38,8 @@ LabelSwitchKernel::~LabelSwitchKernel() {}
bool LabelSwitchKernel::Init(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_LOG(INFO) << "LabelSwitchKernel init";
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, anf_node)) {
auto cnode = anf_node->cast<CNodePtr>();
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) {
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
}
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);

@ -347,9 +347,13 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
auto nodes = TopoSort(kernel_graph->get_return());
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsRealCNodeKernel(node) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, node)) {
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(node, kOpAttrFusionId);
(*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) {
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(cnode, kOpAttrFusionId);
(*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode);
}
}
}

@ -39,7 +39,7 @@ const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, co
}
// 1. memcpy has attr kAttrLabelForInsertStreamActive
if (!AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, node)) {
if (!AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, memcpy_cnode)) {
MS_LOG(DEBUG) << "node has no label_for_insert_stream_active attr";
return nullptr;
}

@ -63,13 +63,14 @@ const AnfNodePtr InsertMemcpyAsyncForGetNext::Process(const FuncGraphPtr &func_g
return nullptr;
}
if (AnfAlgo::HasNodeAttr(kAttrVisited, node)) {
auto cnode = node->cast<CNodePtr>();
if (AnfAlgo::HasNodeAttr(kAttrVisited, cnode)) {
MS_LOG(DEBUG) << "Node op_name[" << kGetNextOpName << "] has visited.";
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cnode);
return InsertMemcpyAsyncForGetNextOutputs(func_graph, node);
return InsertMemcpyAsyncForGetNextOutputs(func_graph, cnode);
}
} // namespace opt
} // namespace mindspore

@ -29,7 +29,14 @@ int32_t FusionIdAllocator::AllocateFusionId() {
return fusion_id;
}
bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) { return AnfAlgo::HasNodeAttr(kAttrFusionId, node); }
bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
return AnfAlgo::HasNodeAttr(kAttrFusionId, cnode);
}
int32_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) {
if (HasFusionIdAttr(node)) {

@ -246,12 +246,8 @@ void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr
primitive->EraseAttr(key);
}
bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const AnfNodePtr &node) {
bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString();
return false;
}
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
return primitive->HasAttr(key);

@ -80,7 +80,7 @@ class AnfRuntimeAlgorithm {
// set all attrs from 'from' node to 'to' node
static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
// check whether a cnode has the specified attr.
static bool HasNodeAttr(const std::string &key, const AnfNodePtr &node);
static bool HasNodeAttr(const std::string &key, const CNodePtr &node);
// delete attr of anf node
static void EraseNodeAttr(const std::string &key, AnfNodePtr node);
// get the num of input real_kernel(which can be build and run in device)

Loading…
Cancel
Save