|
|
|
@ -1571,7 +1571,7 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
|
|
|
|
|
const AnfNodePtr &node) {
|
|
|
|
|
const AnfNodePtr &node, const std::string &op_name) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(res.first);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto cnode = res.first->cast<CNodePtr>();
|
|
|
|
@ -1579,10 +1579,9 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode_prim);
|
|
|
|
|
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
|
|
|
|
|
Operator op;
|
|
|
|
|
CNodePtr allgather;
|
|
|
|
|
if (grad_accumulation_step > 1) {
|
|
|
|
|
if (op_name == MINI_STEP_ALL_GATHER) {
|
|
|
|
|
op = CreateMiniStepAllGatherOp(group);
|
|
|
|
|
auto param_name = node->cast<ParameterPtr>()->name();
|
|
|
|
|
if (cnode_prim->name() == CAST) {
|
|
|
|
@ -1613,21 +1612,41 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
|
|
|
|
|
}
|
|
|
|
|
FuncGraphManagerPtr manager = root->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
|
|
|
|
|
std::string op_name;
|
|
|
|
|
if (grad_accumulation_step > 1) {
|
|
|
|
|
op_name = MINI_STEP_ALL_GATHER;
|
|
|
|
|
} else {
|
|
|
|
|
op_name = ALL_GATHER;
|
|
|
|
|
}
|
|
|
|
|
auto param_sub_set = manager->node_users()[parameter];
|
|
|
|
|
bool insert_flag = false;
|
|
|
|
|
for (auto &param_pair : param_sub_set) {
|
|
|
|
|
auto cnode = param_pair.first->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (cnode->in_forward_flag()) {
|
|
|
|
|
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
|
|
|
|
|
if (distribute_operator == nullptr) {
|
|
|
|
|
MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
|
|
|
|
|
MS_LOG(WARNING) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
|
|
|
|
|
} else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
|
|
|
|
|
<< distribute_operator->inputs_tensor_info().size();
|
|
|
|
|
}
|
|
|
|
|
// insert allgather operator between shard parameter and cnode
|
|
|
|
|
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter);
|
|
|
|
|
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
|
|
|
|
|
if (insert_flag) {
|
|
|
|
|
auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph());
|
|
|
|
|
if (next_cnode.first) {
|
|
|
|
|
manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
|
|
|
|
|
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
|
|
|
|
|
<< GetPrimName(cnode);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// insert allgather operator between shard parameter and cnode
|
|
|
|
|
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
|
|
|
|
|
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
|
|
|
|
|
<< GetPrimName(cnode);
|
|
|
|
|
insert_flag = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|