@@ -97,6 +97,27 @@ void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool acc
  prim->SetAttrs(attrs);
}
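
// If the forward operator's primitive carries RECOMPUTE_COMM_OP == false, set RECOMPUTE = false
// on the newly created communication primitive so the forward communication op is excluded from
// recomputation.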
void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, const CNodePtr &node) {
  if (new_node_input.empty()) {
    return;
  }

  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  auto attrs = prim->attrs();

  auto anf_node = node->input(0)->cast<ValueNodePtr>();
  auto prim_node = GetValueNode<PrimitivePtr>(anf_node);
  MS_EXCEPTION_IF_NULL(prim_node);
  auto node_attrs = prim_node->attrs();
  if (node_attrs.find(RECOMPUTE_COMM_OP) != node_attrs.end() && !GetValue<bool>(node_attrs[RECOMPUTE_COMM_OP])) {
    attrs[RECOMPUTE] = MakeValue<bool>(false);
    prim->SetAttrs(attrs);
    MS_LOG(INFO) << "Do not recompute the forward communication operator of " << prim_node->ToString();
  }
}

std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
  MS_EXCEPTION_IF_NULL(node);
  OperatorArgs arg_forward = op.second;

@@ -353,6 +374,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
    std::string instance_name_base = FORWARD_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
    std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
    SetAllReduceRecomputeFlag(forward_input, node_to_insert);
    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
    MS_EXCEPTION_IF_NULL(forward_node);
    ScopePtr scope = node->scope();

@@ -1165,7 +1187,14 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
    // not a RefKey
    if (!param_node_pair.second) {
      auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph);
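      // Select which mirror operator to search for: the mini-step variant when gradient
      // accumulation spans more than one step, otherwise the plain mirror operator.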
      int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
      std::string mirror_op_name;
      if (grad_accumulation_step > 1) {
        mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
      } else {
        mirror_op_name = MIRROR_OPERATOR;
      }
      auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph);
      // if there is already a MirrorOp in the same graph, use the MirrorOp CNode as an input instead
      if (next_cnode.first) {
        MS_EXCEPTION_IF_NULL(next_cnode.second);

@@ -1743,6 +1772,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
    if (found_be_cloned_parameter) {
      // set the shape and tensor layout for cloned parameter
      std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
      if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
        MS_LOG(WARNING) << "The parameter " << param_name << " has no tensor layout, skip it";
        continue;
      }
      cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
      MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
      MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());

@@ -3298,6 +3331,97 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
  }
}
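
// A parameter counts as fully split when the device group derived from its tensor map contains
// only the current rank, i.e. its slice is not replicated on any other device.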
static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
  auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  auto dev_mat_shape = tensor_layout->device_arrangement().array();
  auto tensor_map = tensor_layout->tensor_map().array();
  int64_t rank = g_device_manager->global_rank();
  RankList rank_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
    return false;
  }

  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
    return true;
  }
  return false;
}
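
// Looks up the gradient-accumulation counterpart of `name` among the graph parameters: a
// different parameter whose name contains both the original name and "accu_grad".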
static AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (param_ptr->name() == name) {
      continue;
    }
    if (param_ptr->name().find(name) != std::string::npos && param_ptr->name().find("accu_grad") != std::string::npos) {
      return parameter;
    }
  }
  return nullptr;
}
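
// Inserts a _VirtualAdd node between the parameter and its user: the new node takes the user's
// original input together with the accumulation parameter and replaces that input edge, which is
// how the gradient of a fully split parameter gets accumulated across micro-steps.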
static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
                                          const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
  auto cnode = node_user.first->cast<CNodePtr>();
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(WARNING) << cnode->DebugString() << " can not insert fully split param grad accumulation node";
    return;
  }
  OperatorAttrs attrs;
  auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
  auto value_node = NewValueNode(py_instance);
  std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter};
  auto graph = cnode->func_graph();
  auto virtual_node = graph->NewCNode(virtual_node_input);
  manager->SetEdge(cnode, node_user.second, virtual_node);
}
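
// For every fully split parameter that has an "accu_grad" counterpart, attach the grad
// accumulation node to its first forward user. Skipped when grad_accumulation_step <= 1 or when
// the graph carries the ACCUMULATION flag.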
static void HandleFullySplitParameters(const FuncGraphPtr &root) {
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) {
    return;
  }

  auto parameters = root->parameters();
  auto node_users_map = root->manager()->node_users();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);

    if (!IsFullySplitParameter(param_ptr)) {
      continue;
    }

    auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
    if (!accu_parameter) {
      continue;  // some parameters do not need to be handled, such as the parameter itself or lr
    }

    auto node_users = node_users_map[parameter];
    for (auto &user : node_users) {
      auto node = user.first;
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (!cnode->in_forward_flag()) {
        continue;
      }
      InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
      MS_LOG(INFO) << "Insert fully split assign add node for " << param_ptr->name();
      break;  // only need to insert once, even if the parameter has many users
    }
  }
}

bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {

@@ -3390,6 +3514,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
    MS_LOG(EXCEPTION) << "Save group info failed";
  }

  // handle fully split parameters in grad accumulation; this does not include optimizer-sharding's parameters
  HandleFullySplitParameters(root);

  DumpGraph(root, std::string(STEP_PARALLEL_END));

  // step parallel only runs once