|
|
|
@ -711,61 +711,6 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
|
|
|
|
|
return tuple_index_value->cast<Int32ImmPtr>()->value();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Judge whether the node is a loss, and if there are multiple outputs,
|
|
|
|
|
// get which output is a grad according to the tuple getitem.
|
|
|
|
|
// Currently, it is not supported that the sens is a tuple.
|
|
|
|
|
LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(loss_node);
|
|
|
|
|
FuncGraphPtr sub_graph = loss_node->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(sub_graph);
|
|
|
|
|
CNodePtr return_node = sub_graph->get_return();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(return_node);
|
|
|
|
|
if (return_node->inputs().size() < 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
|
|
|
|
|
}
|
|
|
|
|
AnfNodePtr pre_node = return_node->input(1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_node);
|
|
|
|
|
|
|
|
|
|
LossNodeInfo node_info;
|
|
|
|
|
|
|
|
|
|
// return -> cast
|
|
|
|
|
auto pre_cnode = pre_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_cnode);
|
|
|
|
|
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
pre_node = pre_cnode->input(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// return -> loss
|
|
|
|
|
if (pre_node == loss_node) {
|
|
|
|
|
node_info.has_tuple_getitem = false;
|
|
|
|
|
node_info.dout_index = 0;
|
|
|
|
|
return node_info;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// return -> tuple_getitem -> loss
|
|
|
|
|
auto cnode = pre_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto current_value = cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(current_value);
|
|
|
|
|
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(current_prim);
|
|
|
|
|
// size of common cnode is larger than 1
|
|
|
|
|
if (cnode->inputs().size() < 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) {
|
|
|
|
|
// size of tuple_getitem cnode is 3
|
|
|
|
|
auto tuple_index = GetTupleGetItemIndex(cnode);
|
|
|
|
|
node_info.has_tuple_getitem = true;
|
|
|
|
|
node_info.dout_index = tuple_index;
|
|
|
|
|
return node_info;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid loss";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
size_t node_size = node->inputs().size();
|
|
|
|
@ -958,13 +903,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
|
|
|
|
|
const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) {
|
|
|
|
|
const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(distribute_operator);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
|
|
|
|
|
bool is_loss_cnode =
|
|
|
|
|
std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
|
|
|
|
|
[node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; });
|
|
|
|
|
[node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
|
|
|
|
|
|
|
|
|
|
MirrorOps mirror_ops = distribute_operator->mirror_ops();
|
|
|
|
|
VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
|
|
|
|
@ -1819,7 +1764,20 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|
|
|
|
CNodePtr HandleDependLoss(const CNodePtr &cnode) {
|
|
|
|
|
// Handle return->depend->loss
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
if (prim->name() == DEPEND) {
|
|
|
|
|
auto depend_before = cnode->input(1)->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(depend_before);
|
|
|
|
|
return HandleDependLoss(depend_before);
|
|
|
|
|
}
|
|
|
|
|
return cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
|
|
|
|
|
LossNodeInfo loss_node_info;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
CNodePtr return_node = func_graph->get_return();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(return_node);
|
|
|
|
@ -1831,9 +1789,9 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|
|
|
|
|
|
|
|
|
auto pre_cnode = pre_node->cast<CNodePtr>();
|
|
|
|
|
if (pre_cnode == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
return loss_node_info;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pre_cnode = HandleDependLoss(pre_cnode);
|
|
|
|
|
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
// return -> cast
|
|
|
|
|
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
@ -1845,7 +1803,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|
|
|
|
// notice: the GetNext op has not input
|
|
|
|
|
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
|
|
|
|
MS_LOG(INFO) << "The loss is: " << current_prim->name();
|
|
|
|
|
return pre_cnode;
|
|
|
|
|
loss_node_info.loss_node = pre_cnode;
|
|
|
|
|
return loss_node_info;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// size of common cnode is larger than 1
|
|
|
|
@ -1855,36 +1814,34 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|
|
|
|
|
|
|
|
|
// return -> tuple_getitem -> loss
|
|
|
|
|
if (current_prim->name() == TUPLE_GETITEM) {
|
|
|
|
|
auto tuple_index = GetTupleGetItemIndex(pre_cnode);
|
|
|
|
|
AnfNodePtr pre_pre_node = pre_cnode->input(1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_pre_node);
|
|
|
|
|
|
|
|
|
|
auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
|
|
|
|
|
auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value);
|
|
|
|
|
PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
MS_LOG(DEBUG) << "The loss name is " << prim->name();
|
|
|
|
|
return pre_pre_cnode;
|
|
|
|
|
loss_node_info.has_tuple_getitem = true;
|
|
|
|
|
loss_node_info.dout_index = tuple_index;
|
|
|
|
|
loss_node_info.loss_node = pre_pre_cnode;
|
|
|
|
|
return loss_node_info;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// return -> make_tuple
|
|
|
|
|
if (current_prim->name() == MAKE_TUPLE) {
|
|
|
|
|
MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
|
|
|
|
|
return nullptr;
|
|
|
|
|
return loss_node_info;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// return -> loss
|
|
|
|
|
loss_node_info.loss_node = pre_cnode;
|
|
|
|
|
MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
|
|
|
|
|
return pre_cnode;
|
|
|
|
|
return loss_node_info;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
|
|
|
|
|
TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) {
|
|
|
|
|
TensorLayouts ret;
|
|
|
|
|
auto loss_cnode = node_info.loss_node;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(loss_cnode);
|
|
|
|
|
AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
|
|
|
|
|
LossNodeInfo node_info = GetLossNodeInfo(node);
|
|
|
|
|
ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim_anf_node);
|
|
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
|
|
|
@ -2086,9 +2043,9 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
|
|
|
|
|
return graph_set;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
|
|
|
|
|
void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
|
|
|
|
|
CNodePtr sens_node = sens_loss_pair.first;
|
|
|
|
|
CNodePtr loss_node = sens_loss_pair.second;
|
|
|
|
|
auto loss_node = sens_loss_pair.second;
|
|
|
|
|
auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
|
|
|
|
|
if (!loss_grad_layout.empty()) {
|
|
|
|
|
SplitSens(sens_node, loss_grad_layout[0]);
|
|
|
|
@ -2096,9 +2053,9 @@ void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
|
|
|
|
|
std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
|
|
|
|
|
std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(root);
|
|
|
|
|
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs;
|
|
|
|
|
std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs;
|
|
|
|
|
for (auto &node : root->nodes()) {
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
continue;
|
|
|
|
@ -2140,12 +2097,12 @@ std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &
|
|
|
|
|
MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
|
|
|
|
|
}
|
|
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
|
|
|
|
|
auto loss_cnode = FindLossCNode(func_graph);
|
|
|
|
|
if (loss_cnode == nullptr) {
|
|
|
|
|
auto loss_node_info = FindLossCNode(func_graph);
|
|
|
|
|
if (loss_node_info.loss_node == nullptr) {
|
|
|
|
|
MS_LOG(WARNING) << "Can not find the loss cnode";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode);
|
|
|
|
|
std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info);
|
|
|
|
|
sens_loss_pairs.push_back(sens_loss_pair);
|
|
|
|
|
}
|
|
|
|
|
return sens_loss_pairs;
|
|
|
|
@ -2157,7 +2114,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
TensorRedistribution tensor_redistribution;
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root);
|
|
|
|
|
std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root);
|
|
|
|
|
bool has_backward = !sens_loss_pairs.empty();
|
|
|
|
|
// split sens must before inserting the operators.
|
|
|
|
|
for (auto &pair : sens_loss_pairs) {
|
|
|
|
@ -2372,7 +2329,7 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
|
|
|
|
|
std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
std::vector<AnfNodePtr> root_forward_nodes;
|
|
|
|
|
auto loss_cnode = FindLossCNode(graph);
|
|
|
|
|
auto loss_cnode = FindLossCNode(graph).loss_node;
|
|
|
|
|
if (loss_cnode == nullptr) {
|
|
|
|
|
MS_LOG(WARNING) << "Can not find the loss cnode";
|
|
|
|
|
return root_forward_nodes;
|
|
|
|
|