@@ -66,8 +66,8 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
     return;
   }
 
-  ValueNodePtr prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
-  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
   MS_EXCEPTION_IF_NULL(prim);
 
   auto attrs = prim->attrs();
@@ -84,6 +84,19 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   }
 }
 
+void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
+  if (new_node_input.empty()) {
+    return;
+  }
+  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+
+  auto attrs = prim->attrs();
+  attrs[DO_MIRROR] = MakeValue<bool>(!accu_flag);
+  prim->SetAttrs(attrs);
+}
+
 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
   MS_EXCEPTION_IF_NULL(node);
   OperatorArgs arg_forward = op.second;
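The SetMiniStepOpDoMirrorLabel helper introduced above stamps a DO_MIRROR attribute onto the communication primitive, set to !accu_flag: while root still carries the ACCUMULATION flag the mirror all-reduce stays off, and it runs on the step that flushes the accumulated gradients. A minimal sketch of that flag logic, with a plain std::map standing in for the MindSpore Primitive attribute dictionary (the copy-then-SetAttrs round trip matches the helper above; everything else here is a stand-in):

#include <iostream>
#include <map>
#include <string>

// Stand-in for prim->attrs(): a plain attribute map instead of MindSpore's
// Primitive attribute dictionary.
using AttrMap = std::map<std::string, bool>;

// Same intent as SetMiniStepOpDoMirrorLabel: while gradients are still being
// accumulated (accu_flag == true) the mirror all-reduce is disabled; on the
// flushing mini step it is enabled.
void SetDoMirrorLabel(AttrMap *attrs, bool accu_flag) { (*attrs)["do_mirror"] = !accu_flag; }

int main() {
  AttrMap attrs;
  SetDoMirrorLabel(&attrs, /*accu_flag=*/true);               // mid-accumulation step
  std::cout << std::boolalpha << attrs["do_mirror"] << "\n";  // false
  SetDoMirrorLabel(&attrs, /*accu_flag=*/false);              // flushing step
  std::cout << std::boolalpha << attrs["do_mirror"] << "\n";  // true
  return 0;
}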
@@ -158,7 +171,6 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(root->manager());
 
-  AnfNodePtr local_step_param = nullptr;
   AnfNodePtr grad_accu = nullptr;
   std::string op_name = op.first;
   OperatorArgs arg_forward = op.second;
@@ -166,25 +178,7 @@
   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
 
   if (grad_accumulation_step > 1) {
-    bool find_locat_step_node = false;
     auto parameters = root->parameters();
-    for (auto &param : parameters) {
-      auto param_ptr = param->cast<ParameterPtr>();
-      MS_EXCEPTION_IF_NULL(param_ptr);
-      if (param_ptr->name() == LOCAL_STEP) {
-        auto param_users = root->manager()->node_users()[param];
-        for (auto &user : param_users) {
-          if (AnfNodeIsPrimitive(user.first, ASSIGN)) {
-            find_locat_step_node = true;
-            local_step_param = user.first;
-            MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode";
-            break;
-          }
-        }
-        break;
-      }
-    }
-
     bool find_grad_accu_node = false;
     for (auto &param : parameters) {
       if (!ParameterIsCloned(param)) {
@@ -202,10 +196,12 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
       }
     }
 
-    if (!find_locat_step_node || !find_grad_accu_node) {
+    if (!find_grad_accu_node) {
+      if (op_name == MIRROR_MINI_STEP_OPERATOR) {
         op_name = MIRROR_OPERATOR;
         arg_forward.first.pop_back();
+      } else if (op_name == MINI_STEP_ALL_GATHER) {
+        MS_LOG(EXCEPTION) << "You should define `accu_grads` when enable gradient accumulation.";
+      }
     }
   }
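When no accu_grads parameter can be found, the block above falls back in two different ways: MirrorMiniStepOperator quietly degrades to the plain MirrorOperator (and pops the trailing operator argument it no longer needs), while MiniStepAllGather has no sensible fallback and raises. A small sketch of that decision, with string literals standing in for the operator-name constants:

#include <iostream>
#include <stdexcept>
#include <string>

// Sketch of the fallback added to CreateMirrorInput: without an accu_grads
// buffer, the mini-step mirror degrades to a plain mirror, while
// MiniStepAllGather must fail.
std::string ResolveOpName(std::string op_name, bool find_grad_accu_node) {
  if (!find_grad_accu_node) {
    if (op_name == "MirrorMiniStepOperator") {
      op_name = "MirrorOperator";  // the real code also drops the now-unused trailing arg
    } else if (op_name == "MiniStepAllGather") {
      throw std::runtime_error("You should define `accu_grads` when enable gradient accumulation.");
    }
  }
  return op_name;
}

int main() {
  std::cout << ResolveOpName("MirrorMiniStepOperator", false) << "\n";  // MirrorOperator
  std::cout << ResolveOpName("MirrorMiniStepOperator", true) << "\n";   // MirrorMiniStepOperator
  return 0;
}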
@@ -215,9 +211,9 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
   OperatorParams params = arg_forward.second;
 
   std::vector<AnfNodePtr> new_node_input;
-  if (op_name == MIRROR_MINI_STEP_OPERATOR) {
-    new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu};
-    MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input";
+  if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER) {
+    new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
+    MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
   } else {
     new_node_input = {NewValueNode(pyop_instance), node};
   }
@@ -233,6 +229,10 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
 
   // if the op have 'group' attr, set the rank list name for the op
   SetCommunicationOpGroupLabel(new_node_input);
+  // gradient accumulation
+  if (grad_accumulation_step > 1) {
+    SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION));
+  }
   return new_node_input;
 }
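With the local-step input gone, both mini-step operators now receive exactly one extra input, the gradient-accumulation buffer, and the DO_MIRROR label is applied only when accumulation is active. A sketch of the reshaped input list, with strings standing in for ANF node pointers and the operator-name constants:

#include <iostream>
#include <string>
#include <vector>

// Sketch of the input list built in CreateMirrorInput after this patch:
// mini-step operators carry the accumulation buffer as a third input,
// everything else keeps the two-input form.
std::vector<std::string> BuildMirrorInput(const std::string &op_name) {
  const std::string prim = "prim(" + op_name + ")";
  if (op_name == "MirrorMiniStepOperator" || op_name == "MiniStepAllGather") {
    return {prim, "node", "grad_accu"};  // accumulation buffer as extra input
  }
  return {prim, "node"};
}

int main() {
  for (const auto &name : {"MirrorMiniStepOperator", "MirrorOperator"}) {
    std::cout << name << ":";
    for (const auto &input : BuildMirrorInput(name)) std::cout << " " << input;
    std::cout << "\n";
  }
  return 0;
}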
@@ -285,6 +285,31 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
   return new_node;
 }
 
+// Replace pre_node with pre_node->op
+static CNodePtr ReplaceMirrorNode(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &pre_node,
+                                  const FuncGraphPtr &func_graph, const std::string &instance_name,
+                                  const std::string &param_name) {
+  // insert new node before the node
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  ScopePtr scope = pre_node->scope();
+  MS_EXCEPTION_IF_NULL(scope);
+  std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
+  CNodePtr new_node = func_graph->NewCNode(node_input);
+  MS_EXCEPTION_IF_NULL(new_node);
+  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
+    new_node->set_in_forward_flag(true);  // mark forward flag
+  }
+  auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
+  new_node_prim->set_instance_name(instance_name);
+  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
+  new_node->set_scope(scope);
+  node_input[0]->set_scope(scope);
+  manager->Replace(pre_node, new_node);
+  MS_LOG(INFO) << "Insert " << instance_name << " success";
+  return new_node;
+}
+
 std::string CreateInstanceName(const CNodePtr &node, size_t index) {
   MS_EXCEPTION_IF_NULL(node);
   if (!IsValueNode<Primitive>(node->input(0))) {
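ReplaceMirrorNode parallels the existing ReplaceNode helper but builds its inputs through CreateMirrorInput, so the inserted primitive picks up the grad-accumulation input and labels. The graph surgery itself is the usual build-then-splice pattern; a toy sketch, with shared_ptr nodes standing in for AnfNode and the FuncGraphManager:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for AnfNode and the manager: wrap an existing node in a new
// operator node, then splice the new node into every place the old one was used.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr WrapAndReplace(std::vector<NodePtr> *user_inputs, const NodePtr &pre_node, const std::string &op) {
  auto new_node = std::make_shared<Node>(Node{op, {pre_node}});  // NewCNode({prim, pre_node, ...})
  for (auto &slot : *user_inputs) {                              // manager->Replace(pre_node, new_node)
    if (slot == pre_node) slot = new_node;
  }
  return new_node;
}

int main() {
  auto param = std::make_shared<Node>(Node{"param", {}});
  std::vector<NodePtr> consumer_inputs{param};
  auto mirror = WrapAndReplace(&consumer_inputs, param, "MirrorMiniStep");
  std::cout << consumer_inputs[0]->name << " <- " << mirror->inputs[0]->name << "\n";  // MirrorMiniStep <- param
  return 0;
}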
@@ -1086,29 +1111,6 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
   return (type_id != kNumberTypeFloat32);
 }
 
-static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
-  MS_EXCEPTION_IF_NULL(comm_node);
-  MS_EXCEPTION_IF_NULL(param_node);
-  if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
-    MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
-    return;
-  }
-  auto param = param_node->cast<ParameterPtr>();
-  MS_EXCEPTION_IF_NULL(param);
-  auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
-  MS_EXCEPTION_IF_NULL(prim);
-  auto attrs = prim->attrs();
-  auto param_info = param->param_info();
-  if (!param_info) {
-    MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
-    return;
-  }
-  int32_t fusion_type = param_info->comm_fusion();
-  attrs[FUSION] = MakeValue<int64_t>(fusion_type);
-  prim->SetAttrs(attrs);
-  MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
-}
-
 static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
   if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
     MS_LOG(INFO) << "Input is ValueList, skip it.";
@@ -1195,7 +1197,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
         InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
         auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
         // add fusion flag
-        // pipeline mirror would not be set, which should be supported later
         AddCommOpFusionType(comm_op, param_node_pair.first);
       }
       continue;
@@ -1540,33 +1541,40 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
   return std::make_pair(nullptr, 0);
 }
 
-static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res,
-                              const AnfNodePtr &parameter) {
-  Operator op = CreateAllGatherOp(group);
+static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
+                              const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(res.first);
-  MS_EXCEPTION_IF_NULL(parameter);
+  MS_EXCEPTION_IF_NULL(node);
   auto cnode = res.first->cast<CNodePtr>();
   auto graph = cnode->func_graph();
   MS_EXCEPTION_IF_NULL(graph);
   auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
   MS_EXCEPTION_IF_NULL(cnode_prim);
+  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+  Operator op;
   CNodePtr allgather;
+  if (grad_accumulation_step > 1) {
+    op = CreateMiniStepAllGatherOp(group);
+    auto param_name = node->cast<ParameterPtr>()->name();
+    if (cnode_prim->name() == CAST) {
+      allgather = ReplaceMirrorNode(root, op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
+    } else {
+      InsertMirrorNode(root, op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
+      allgather = cnode->input(res.second)->cast<CNodePtr>();
+    }
+  } else {
+    op = CreateAllGatherOp(group);
     if (cnode_prim->name() == CAST) {
       allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER);
     } else {
-      InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
+      InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER);
       allgather = cnode->input(res.second)->cast<CNodePtr>();
     }
     MS_EXCEPTION_IF_NULL(allgather);
+  }
   // add fusion flag
-  AddCommOpFusionType(allgather, parameter);
+  AddCommOpFusionType(allgather, node);
   // add gradients mean
-  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
-  auto attrs = prim->attrs();
-  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
-  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
-  attrs["mean_flag"] = MakeValue<bool>(mean_flag);
-  prim->SetAttrs(attrs);
+  AddCommOpMeanFlag(allgather);
 }
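InsertAllGatherOp now branches on grad_accumulation_step: with accumulation enabled it emits a MiniStepAllGather through the mirror-input path (so the operator carries the accu_grads buffer and the parameter name), otherwise the plain AllGather as before. Reduced to the operator choice alone, the new logic is:

#include <cstdint>
#include <iostream>
#include <string>

// Sketch of the operator selection added to InsertAllGatherOp: gradient
// accumulation switches the optimizer-shard gather to MiniStepAllGather.
std::string ChooseAllGatherOp(int64_t grad_accumulation_step) {
  return grad_accumulation_step > 1 ? "MiniStepAllGather" : "AllGather";
}

int main() {
  std::cout << ChooseAllGatherOp(1) << "\n";  // AllGather
  std::cout << ChooseAllGatherOp(8) << "\n";  // MiniStepAllGather
  return 0;
}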
@@ -1589,7 +1597,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
                     << distribute_operator->inputs_tensor_info().size();
     }
     // insert allgather operator between shard parameter and cnode
-    InsertAllGatherOp(opt_shard_group, param_pair, parameter);
+    InsertAllGatherOp(root, opt_shard_group, param_pair, parameter);
     MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
   }
 }
@@ -1734,12 +1742,20 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
 
     if (found_be_cloned_parameter) {
       // set the shape and tensor layout for cloned parameter
+      std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
       cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
       MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
       MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
       auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
       MS_EXCEPTION_IF_NULL(cloned_abstract);
+      if (param_name.find(ACCU_GRADS) != std::string::npos) {
+        auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
+        std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
+        MS_EXCEPTION_IF_NULL(parallel_shape);
+        cloned_abstract->set_shape(parallel_shape);
+      } else {
        cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      }
       cloned_parameter_node->set_abstract(cloned_abstract);
       MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                    << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
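The new branch matters because an accu_grads.* clone holds one device's accumulated gradients, so its abstract shape must be the slice shape from the source parameter's TensorLayout rather than the full global shape; the else path keeps the old behaviour for every other cloned parameter. A sketch of the shape choice, with shapes as plain int64 vectors and an assumed 8-way shard on the first dimension:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the shape decision added to SetClonedTensorShapeForOptimizer:
// accu_grads clones take the per-device slice shape, all other clones take
// the shape of the parameter they were cloned from.
std::vector<int64_t> ClonedShape(bool is_accu_grads, const std::vector<int64_t> &full_shape,
                                 const std::vector<int64_t> &slice_shape) {
  return is_accu_grads ? slice_shape : full_shape;
}

int main() {
  std::vector<int64_t> full{1024, 1024}, slice{128, 1024};  // hypothetical 8-way shard on dim 0
  auto shape = ClonedShape(/*is_accu_grads=*/true, full, slice);
  std::cout << shape[0] << "x" << shape[1] << "\n";  // 128x1024
  return 0;
}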