|
|
|
@ -618,19 +618,11 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
|
|
|
|
|
for (auto &replace_input : replace_graph->first) {
|
|
|
|
|
auto pre_node = node->input(IntToSize(replace_input.second));
|
|
|
|
|
manager->SetEdge(replace_input.first, 1, pre_node);
|
|
|
|
|
auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(replace_input_cnode);
|
|
|
|
|
(void)replace_input_cnode->set_operator_info(node->operator_info());
|
|
|
|
|
replace_input_cnode->set_in_forward_flag(true); // mark this new cnode is forward node
|
|
|
|
|
}
|
|
|
|
|
// "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
|
|
|
|
|
auto replace_output = replace_graph->second;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(replace_output);
|
|
|
|
|
(void)manager->Replace(node, replace_output);
|
|
|
|
|
CNodePtr replace_output_cnode = replace_graph->second->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(replace_output_cnode);
|
|
|
|
|
(void)replace_output_cnode->set_operator_info(node->operator_info());
|
|
|
|
|
replace_output_cnode->set_in_forward_flag(true); // mark this new cnode is forward node
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
|
|
|
|
@ -1994,14 +1986,27 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|
|
|
|
BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// StepReplace
|
|
|
|
|
StepReplace(distribute_operator, cnode);
|
|
|
|
|
|
|
|
|
|
HandleSpecialNode(distribute_operator, cnode);
|
|
|
|
|
} else if (IsValueNode<Tensor>(node)) {
|
|
|
|
|
StepSplitTensor(node, manager);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
|
|
|
|
|
if (distribute_operator == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// StepReplace
|
|
|
|
|
StepReplace(distribute_operator, cnode);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|