|
|
|
@ -133,8 +133,21 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_node);
|
|
|
|
|
auto inputs = c_node->inputs();
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs;
|
|
|
|
|
(void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
|
|
|
|
|
[this](const AnfNodePtr &inp) -> AnfNodePtr { return ReplicateDisconnectedNode(inp); });
|
|
|
|
|
(void)std::transform(
|
|
|
|
|
inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [this](const AnfNodePtr &inp) -> AnfNodePtr {
|
|
|
|
|
auto new_inp = ReplicateDisconnectedNode(inp);
|
|
|
|
|
// Refer the comments in BuildReplacedNode.
|
|
|
|
|
if (inp->isa<CNode>()) {
|
|
|
|
|
auto c_inp = inp->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_inp);
|
|
|
|
|
auto c_new_inp = new_inp->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_new_inp);
|
|
|
|
|
MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> " << new_inp->DebugString();
|
|
|
|
|
c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp);
|
|
|
|
|
}
|
|
|
|
|
return new_inp;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
auto c_new_node = new_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_new_node);
|
|
|
|
|
c_new_node->set_inputs(new_inputs);
|
|
|
|
|