@@ -808,11 +808,17 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
   }
 }
 
+// Only used for InsertMirrorOps
 std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
     return std::make_pair(nullptr, false);
   } else if (node->isa<Parameter>()) {
-    return std::make_pair(node, false);
+    auto param_ptr = node->user_data<parallel::TensorLayout>();
+    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+      return std::make_pair(nullptr, false);
+    } else {
+      return std::make_pair(node, false);
+    }
   } else if (node->isa<ValueNode>()) {
     if (IsValueNode<RefKey>(node)) {
       std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
@@ -820,7 +826,12 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
         MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
                           << param_v.size();
       }
-      return std::make_pair(node, true);
+      auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
+      if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+        return std::make_pair(nullptr, true);
+      } else {
+        return std::make_pair(node, true);
+      }
     }
     return std::make_pair(nullptr, false);
   } else {
@@ -1002,7 +1013,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
   MirrorOps mirror_ops = distribute_operator->mirror_ops();
   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
   // insert mirror op
-  if (!mirror_ops.empty() && !distribute_operator->opt_shard_flag()) {
+  if (!mirror_ops.empty()) {
     MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
     InsertMirrorOps(mirror_ops, node);
   }
@@ -1374,39 +1385,51 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
   return std::make_pair(nullptr, 0);
 }
 
-void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
-                             const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index) {
-  MS_EXCEPTION_IF_NULL(distribute_operator);
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(parameter);
-  std::vector<Group> dev_group;
-  // create communication group for allgather operator
-  if (distribute_operator->CreateGroupByTensorMap(tensor_layout->origin_tensor_map().array(), &dev_group) ==
-        Status::SUCCESS &&
-      !dev_group.empty()) {
-    // set optimizer shard split flag to avoid inserting mirror_ops
-    distribute_operator->set_opt_shard_flag(true);
-    // insert allgather operator between shard parameter and cnode
-    Operator op = CreateAllGatherOp(dev_group[0].name());
-    auto graph = cnode->func_graph();
-    MS_EXCEPTION_IF_NULL(graph);
-    InsertNode(op, cnode, index, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
-    // set communication group in tensor layout for checkpoint saving
-    tensor_layout->set_opt_shard_group(dev_group[0].name());
-    // add fusion flag
-    auto allgather = cnode->input(index)->cast<CNodePtr>();
-    auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
-    auto attrs = prim->attrs();
-    // enable fusion flag later when it's supported in backend
-    attrs["fusion"] = MakeValue(0);
-    prim->SetAttrs(attrs);
-    MS_LOG(INFO) << "Parallel optimizer is applied on " << parameter->ToString();
-  } else {
-    MS_LOG(ERROR) << "Parallel optimizer applied on " << parameter->ToString() << "failed!";
-  }
-}
+void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, const AnfNodePtr &parameter) {
+  Operator op = CreateAllGatherOp(group);
+  MS_EXCEPTION_IF_NULL(res.first);
+  MS_EXCEPTION_IF_NULL(parameter);
+  auto cnode = res.first->cast<CNodePtr>();
+  auto graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(graph);
+  InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
+  // add fusion flag
+  auto allgather = cnode->input(res.second)->cast<CNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
+  auto attrs = prim->attrs();
+  // enable fusion flag later when it's supported in backend
+  attrs["fusion"] = MakeValue(0);
+  prim->SetAttrs(attrs);
+}
+
+void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
+                             const std::string &opt_shard_group) {
+  if (opt_shard_group.empty()) {
+    return;
+  }
+  FuncGraphManagerPtr manager = root->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto param_sub_set = manager->node_users()[parameter];
+  for (auto &param_pair : param_sub_set) {
+    auto cnode = param_pair.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (cnode->in_forward_flag()) {
+      OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
+      if (distribute_operator == nullptr) {
+        MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
+      } else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
+        MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
+                          << distribute_operator->inputs_tensor_info().size();
+      }
+      // insert allgather operator between shard parameter and cnode
+      InsertAllGatherOp(opt_shard_group, param_pair, parameter);
+      MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
+    }
+  }
+}
 
-void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
+// When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
+std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
   MS_EXCEPTION_IF_NULL(parameter);
   AbstractBasePtr abstract = parameter->abstract();
   MS_EXCEPTION_IF_NULL(abstract);
@@ -1417,26 +1440,40 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
   if (distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
   }
-
   if (IntToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
     MS_LOG(EXCEPTION) << "The index is out of range, index is " << res.second - 1 << ", vector size is "
                       << distribute_operator->inputs_tensor_info().size();
   }
   TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
-  Shape slice_shape = tensor_layout.slice_shape().array();
+  std::string opt_shard_group;
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
+  Shape slice_shape = tensor_layout.slice_shape().array();
   if (enable_parallel_optimizer) {
     if (!ParameterRequireGrad(parameter)) {
       // only trainable parameters need parallel optimizer
-      MS_LOG(INFO) << "Parallel optimizer is no need for " << parameter->ToString();
+      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
     } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
       // get a totally shard tensor slice shape if the weight is repeated on devices
       // and the shape of the first dimension could be divided
       // apply parallel optimizer on parameters
-      ApplyParallelOptOnParam(&tensor_layout, distribute_operator, cnode, parameter, IntToSize(res.second));
+      // create communication group for allgather operator
       slice_shape = tensor_layout.opt_shard_slice_shape();
+      std::vector<Group> dev_group;
+      if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
+            Status::SUCCESS &&
+          !dev_group.empty()) {
+        opt_shard_group = dev_group[0].name();
+        // set communication group in tensor layout for checkpoint saving
+        tensor_layout.set_opt_shard_group(opt_shard_group);
+        MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
+                     << " success.";
+      } else {
+        MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
+      }
+    } else {
+      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
     }
   }
   MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
@@ -1451,6 +1488,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
   parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
+  return opt_shard_group;
 }
 
 void CoverSliceShape(const FuncGraphPtr &root) {
@@ -1460,14 +1498,18 @@ void CoverSliceShape(const FuncGraphPtr &root) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
     auto iter = g_RefMap.find(parameter);
     if (iter != g_RefMap.end()) {
-      SetParallelShape(parameter, g_RefMap[parameter]);
+      std::string group = SetParallelShape(parameter, g_RefMap[parameter]);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
+      ApplyParallelOptOnParam(root, parameter, group);
       continue;
     }
    std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
-      SetParallelShape(parameter, res);
+      std::string group = SetParallelShape(parameter, res);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
+      ApplyParallelOptOnParam(root, parameter, group);
       MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }