|
|
|
@ -62,6 +62,16 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) {
|
|
|
|
return kernel_with_index;
|
|
|
|
return kernel_with_index;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index,
|
|
|
|
|
|
|
|
const size_t input_index) {
|
|
|
|
|
|
|
|
// record the ref_pair
|
|
|
|
|
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
|
|
|
session::AnfWithOutIndex final_pair = std::make_pair(cnode, output_index);
|
|
|
|
|
|
|
|
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0);
|
|
|
|
|
|
|
|
kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item,
|
|
|
|
void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item,
|
|
|
|
const AnfNodePtr &final_node, size_t final_index,
|
|
|
|
const AnfNodePtr &final_node, size_t final_index,
|
|
|
|
const session::KernelWithIndex &origin_pair) {
|
|
|
|
const session::KernelWithIndex &origin_pair) {
|
|
|
|
@ -88,6 +98,7 @@ void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cno
|
|
|
|
AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index,
|
|
|
|
AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index,
|
|
|
|
size_t input_index, const AnfNodePtr &get_item) {
|
|
|
|
size_t input_index, const AnfNodePtr &get_item) {
|
|
|
|
AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item);
|
|
|
|
AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item);
|
|
|
|
|
|
|
|
bool need_refresh_ref_addr = false;
|
|
|
|
size_t final_index = output_index;
|
|
|
|
size_t final_index = output_index;
|
|
|
|
AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index);
|
|
|
|
AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index);
|
|
|
|
session::KernelWithIndex origin_pair;
|
|
|
|
session::KernelWithIndex origin_pair;
|
|
|
|
@ -109,6 +120,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
|
|
|
|
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
|
|
|
|
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
|
|
|
|
RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
|
|
|
|
RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
|
|
|
|
final_index = 0;
|
|
|
|
final_index = 0;
|
|
|
|
|
|
|
|
need_refresh_ref_addr = true;
|
|
|
|
MS_EXCEPTION_IF_NULL(final_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(final_node);
|
|
|
|
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
|
|
|
|
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -119,15 +131,19 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
|
|
|
|
MS_EXCEPTION_IF_NULL(final_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(final_node);
|
|
|
|
final_node->set_scope(cnode->scope());
|
|
|
|
final_node->set_scope(cnode->scope());
|
|
|
|
final_index = 0;
|
|
|
|
final_index = 0;
|
|
|
|
|
|
|
|
need_refresh_ref_addr = true;
|
|
|
|
MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString();
|
|
|
|
MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// add ref pair
|
|
|
|
// add ref pair
|
|
|
|
AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair);
|
|
|
|
AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair);
|
|
|
|
|
|
|
|
if (need_refresh_ref_addr) {
|
|
|
|
|
|
|
|
AddRefNodePairToKernelGraph(func_graph, cnode, output_index, input_index);
|
|
|
|
|
|
|
|
}
|
|
|
|
// insert depend
|
|
|
|
// insert depend
|
|
|
|
if (origin_format != cur_format || origin_type != cur_type) {
|
|
|
|
if (origin_format != cur_format || origin_type != cur_type) {
|
|
|
|
std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node};
|
|
|
|
std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node};
|
|
|
|
final_node = func_graph->NewCNode(depend_nodes);
|
|
|
|
final_node = func_graph->NewCNode(depend_nodes);
|
|
|
|
MS_LOG(INFO) << "DealRefTransAndCast add denpend, op debug info is " << final_node->DebugString();
|
|
|
|
MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return final_node;
|
|
|
|
return final_node;
|
|
|
|
|