@@ -128,6 +128,137 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
  MS_LOG(INFO) << "Insert " << instance_name << " success";
}
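// A parameter counts as cloned only when it has a default value and its param_info is marked as cloned.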
bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // find the clone parameter
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  bool cloned = param_value->cloned();
  if (!cloned) {
    return false;
  }

  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
  return true;
}
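// Build the mirror operator's input list for one weight. In mini-step gradient accumulation mode the
// local step node and the accumulated gradient parameter are appended as extra inputs.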
std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
                                          const std::string &instance_name, const std::string &weight_name) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(root->manager());

  AnfNodePtr local_step_param = nullptr;
  AnfNodePtr grad_accu = nullptr;
  std::string op_name = op.first;
  OperatorArgs arg_forward = op.second;

  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
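  // When gradient accumulation is enabled, locate the local step node and the accumulated gradient
  // parameter that the mini-step mirror operator needs as extra inputs.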
  if (grad_accumulation_step > 1) {
    bool find_locat_step_node = false;
    auto parameters = root->parameters();
    for (auto &param : parameters) {
      auto param_ptr = param->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param_ptr);
      if (param_ptr->name() == LOCAL_STEP) {
        auto param_users = root->manager()->node_users()[param];
        for (auto &user : param_users) {
          if (AnfNodeIsPrimitive(user.first, ASSIGN)) {
            find_locat_step_node = true;
            local_step_param = user.first;
            MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode";
            break;
          }
        }
        break;
      }
    }
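    // Among the cloned parameters, look for the accu_grads copy that corresponds to this weight.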
    bool find_grad_accu_node = false;
    for (auto &param : parameters) {
      if (!ParameterIsCloned(param)) {
        continue;
      }

      auto param_ptr = param->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param_ptr);
      if (param_ptr->name().find(weight_name) != std::string::npos &&
          param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
        find_grad_accu_node = true;
        grad_accu = param;
        MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
        break;
      }
    }
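    // If either node is missing, fall back to the plain mirror operator and drop its mini-step argument.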
    if (op_name == MIRROR_MINI_STEP_OPERATOR) {
      if (!find_locat_step_node || !find_grad_accu_node) {
        op_name = MIRROR_OPERATOR;
        arg_forward.first.pop_back();
      }
    }
  }
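  // Instantiate the (possibly adjusted) mirror operator and assemble the new node's inputs.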
  ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name);
  MS_EXCEPTION_IF_NULL(pyop_instance);
  OperatorParams params = arg_forward.second;

  std::vector<AnfNodePtr> new_node_input;
  if (op_name == MIRROR_MINI_STEP_OPERATOR) {
    new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu};
    MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input";
  } else {
    new_node_input = {NewValueNode(pyop_instance), node};
  }
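  // Insert any constant operator arguments at the input positions they specify.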
  if (!params.empty()) {
    for (auto &param : params) {
      AnfNodePtr val = NewValueNode(param.first.second);
      MS_EXCEPTION_IF_NULL(val);
      int64_t position = param.second;
      (void)new_node_input.insert(new_node_input.begin() + position, val);
    }
  }

  // if the op has a 'group' attr, set the rank list name for the op
  SetCommunicationOpGroupLabel(new_node_input);
  return new_node_input;
}
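// Create a mirror node via CreateMirrorInput and splice it in as the index-th input of the given node.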
void InsertMirrorNode(const FuncGraphPtr &root, const Operator &op, const CNodePtr &node, size_t index,
                      const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, const std::string &instance_name,
                      const std::string &param_name) {
  // insert new node before the node
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  ScopePtr scope = node->scope();
  MS_EXCEPTION_IF_NULL(scope);
  std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
  CNodePtr new_node = func_graph->NewCNode(node_input);
  MS_EXCEPTION_IF_NULL(new_node);
  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
    new_node->set_in_forward_flag(true); // mark forward flag
  }
  auto new_node_value = node_input[0]->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(new_node_value);
  PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
  new_node_prim->set_instance_name(instance_name);
  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
  new_node->set_scope(scope);
  node_input[0]->set_scope(scope);
  manager->SetEdge(node, SizeToLong(index), new_node);
  MS_LOG(INFO) << "Insert " << instance_name << " success";
}

// Replace pre_node with pre_node->op
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
                            const std::string &instance_name) {
@@ -965,7 +1096,7 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par
  MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
}

void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  size_t node_size = node->inputs().size();
  FuncGraphPtr func_graph = node->func_graph();
@@ -997,6 +1128,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
    if (!param_node_pair.first) {
      continue;
    }
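    // Record the parameter name so the mirror op can find the matching accu_grads clone.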
    auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
    std::string param_name;
    if (param_ptr != nullptr) {
      param_name = param_ptr->name();
    }

    // not a RefKey
    if (!param_node_pair.second) {
      auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph);
@@ -1028,7 +1166,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
      CNodePtr cnode = node->input(index)->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      AnfNodePtr pre_node = cnode->input(1);
      InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name);
      InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
      auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
      // add fusion flag
      // pipeline mirror would not be set, which should be supported later
@@ -1037,7 +1175,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
    } else {
      for (auto &op : backward_op) {
        AnfNodePtr pre_node = node->input(index);
        InsertNode(op, node, index, pre_node, func_graph, instance_name);
        InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
        auto comm_op = node->input(index)->cast<CNodePtr>();
        // add fusion flag
        // pipeline mirror would not be set, which should be supported later
@@ -1047,7 +1185,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
  }
}

void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
                           const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
  MS_EXCEPTION_IF_NULL(distribute_operator);
  MS_EXCEPTION_IF_NULL(node);
@@ -1061,7 +1199,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
  // insert mirror op
  if (!mirror_ops.empty()) {
    MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
    InsertMirrorOps(mirror_ops, node);
    InsertMirrorOps(root, mirror_ops, node);
  }
  // insert virtual div op
  if (!virtual_div_op.empty() && is_loss_cnode) {
@@ -1519,28 +1657,6 @@ void CoverSliceShape(const FuncGraphPtr &root) {
  g_RefMap.clear();
}

bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // find the clone parameter
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  bool cloned = param_value->cloned();
  if (!cloned) {
    return false;
  }

  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
  return true;
}

void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &cloned_parameter_node : root->parameters()) {
@@ -2459,7 +2575,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
    // insert backward ops
    if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) {
      BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
      BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
    }

    HandleSpecialNode(distribute_operator, cnode);