!7946 fix auto parallel optimizer weight shard

Merge pull request !7946 from gziyan/fix_auto_optim_shard
pull/7946/MERGE
mindspore-ci-bot committed via Gitee
commit e909b9077c

@@ -164,8 +164,6 @@ class OperatorInfo {
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
int32_t stage_id() const { return stage_id_; }
void set_opt_shard_flag(bool flag) { opt_shard_flag_ = flag; }
bool opt_shard_flag() { return opt_shard_flag_; }
Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
// Key for user data.
@@ -269,7 +267,6 @@ class OperatorInfo {
private:
OperatorCostPtr operator_cost_;
std::vector<TypePtr> outputs_type_;
bool opt_shard_flag_ = false;
};
Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);

@@ -808,11 +808,17 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
}
}
// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
return std::make_pair(nullptr, false);
} else if (node->isa<Parameter>()) {
return std::make_pair(node, false);
auto param_ptr = node->user_data<parallel::TensorLayout>();
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
return std::make_pair(nullptr, false);
} else {
return std::make_pair(node, false);
}
} else if (node->isa<ValueNode>()) {
if (IsValueNode<RefKey>(node)) {
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
@@ -820,7 +826,12 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
<< param_v.size();
}
return std::make_pair(node, true);
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
return std::make_pair(nullptr, true);
} else {
return std::make_pair(node, true);
}
}
return std::make_pair(nullptr, false);
} else {
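With this hunk, FindParameter (used only by InsertMirrorOps) stops reporting a parameter once its TensorLayout user data carries a non-empty opt_shard_group, so no mirror op is inserted for optimizer-sharded weights. A minimal Python sketch of that decision, with a plain dict standing in for parallel::TensorLayout:

    def needs_mirror(layout):
        # Mirror ops are skipped when the parameter is already sharded by the
        # parallel optimizer, i.e. its layout records a non-empty opt_shard_group.
        return not (layout or {}).get("opt_shard_group")

    assert needs_mirror(None)
    assert needs_mirror({"opt_shard_group": ""})
    assert not needs_mirror({"opt_shard_group": "4-1_2_3_4"})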
@@ -1002,7 +1013,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
MirrorOps mirror_ops = distribute_operator->mirror_ops();
VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
// insert mirror op
if (!mirror_ops.empty() && !distribute_operator->opt_shard_flag()) {
if (!mirror_ops.empty()) {
MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
InsertMirrorOps(mirror_ops, node);
}
@@ -1374,39 +1385,51 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
return std::make_pair(nullptr, 0);
}
void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index) {
MS_EXCEPTION_IF_NULL(distribute_operator);
MS_EXCEPTION_IF_NULL(cnode);
void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, const AnfNodePtr &parameter) {
Operator op = CreateAllGatherOp(group);
MS_EXCEPTION_IF_NULL(res.first);
MS_EXCEPTION_IF_NULL(parameter);
std::vector<Group> dev_group;
// create communication group for allgather operator
if (distribute_operator->CreateGroupByTensorMap(tensor_layout->origin_tensor_map().array(), &dev_group) ==
Status::SUCCESS &&
!dev_group.empty()) {
// set optimizer shard split flag to avoid inserting mirror_ops
distribute_operator->set_opt_shard_flag(true);
// insert allgather operator between shard parameter and cnode
Operator op = CreateAllGatherOp(dev_group[0].name());
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
InsertNode(op, cnode, index, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
// set communication group in tensor layout for checkpoint saving
tensor_layout->set_opt_shard_group(dev_group[0].name());
// add fusion flag
auto allgather = cnode->input(index)->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
auto attrs = prim->attrs();
// enable fusion flag later when it's supported in backend
attrs["fusion"] = MakeValue(0);
prim->SetAttrs(attrs);
MS_LOG(INFO) << "Parallel optimizer is applied on " << parameter->ToString();
} else {
MS_LOG(ERROR) << "Parallel optimizer applied on " << parameter->ToString() << "failed!";
auto cnode = res.first->cast<CNodePtr>();
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
// add fusion flag
auto allgather = cnode->input(res.second)->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
auto attrs = prim->attrs();
// enable fusion flag later when it's supported in backend
attrs["fusion"] = MakeValue(0);
prim->SetAttrs(attrs);
}
void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
const std::string &opt_shard_group) {
if (opt_shard_group.empty()) {
return;
}
FuncGraphManagerPtr manager = root->manager();
MS_EXCEPTION_IF_NULL(manager);
auto param_sub_set = manager->node_users()[parameter];
for (auto &param_pair : param_sub_set) {
auto cnode = param_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->in_forward_flag()) {
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
} else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
<< distribute_operator->inputs_tensor_info().size();
}
// insert allgather operator between shard parameter and cnode
InsertAllGatherOp(opt_shard_group, param_pair, parameter);
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
}
}
}
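The new ApplyParallelOptOnParam walks every forward user of the parameter and calls InsertAllGatherOp to place an AllGather over the optimizer shard group between the parameter and the consuming node. A small numpy sketch of what that AllGather restores; the group size and shapes below are made up for illustration:

    import numpy as np

    group_size = 4                                # hypothetical optimizer shard group size
    full_slice = np.arange(32.0).reshape(8, 4)    # the model-parallel slice held before sharding
    shards = np.split(full_slice, group_size)     # each rank keeps 1/group_size along dim 0

    # The inserted AllGather concatenates the shards in rank order, so the
    # consuming forward op still sees the complete slice.
    gathered = np.concatenate(shards, axis=0)
    assert np.array_equal(gathered, full_slice)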
void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
// When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
MS_EXCEPTION_IF_NULL(parameter);
AbstractBasePtr abstract = parameter->abstract();
MS_EXCEPTION_IF_NULL(abstract);
@@ -1417,26 +1440,40 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
}
if (IntToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
MS_LOG(EXCEPTION) << "The index is out of range, index is " << res.second - 1 << ", vector size is "
<< distribute_operator->inputs_tensor_info().size();
}
TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
Shape slice_shape = tensor_layout.slice_shape().array();
std::string opt_shard_group;
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
Shape slice_shape = tensor_layout.slice_shape().array();
if (enable_parallel_optimizer) {
if (!ParameterRequireGrad(parameter)) {
// only trainable parameters need parallel optimizer
MS_LOG(INFO) << "Parallel optimizer is no need for " << parameter->ToString();
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
} else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
// get a totally shard tensor slice shape if the weight is repeated on devices
// and the shape of the first dimension could be divided
// apply parallel optimizer on parameters
ApplyParallelOptOnParam(&tensor_layout, distribute_operator, cnode, parameter, IntToSize(res.second));
// create communication group for allgather operator
slice_shape = tensor_layout.opt_shard_slice_shape();
std::vector<Group> dev_group;
if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
Status::SUCCESS &&
!dev_group.empty()) {
opt_shard_group = dev_group[0].name();
// set communication group in tensor layout for checkpoint saving
tensor_layout.set_opt_shard_group(opt_shard_group);
MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
<< " success.";
} else {
MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
}
} else {
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
}
}
MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
@@ -1451,6 +1488,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter_ptr);
parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
return opt_shard_group;
}
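SetParallelShape now only computes the group and returns its name (empty when optimizer sharding is not applied); the AllGather insertion happens later from CoverSliceShape. A back-of-the-envelope check of the slice-shape arithmetic this relies on, assuming a hypothetical weight split along its first dimension and repeated on 4 devices:

    weight_shape = (64, 32)
    model_parallel = 8    # slices produced by the operator strategy
    repeat = 4            # devices on which each slice is repeated (32 devices / 8 slices)

    slice_shape = (weight_shape[0] // model_parallel, weight_shape[1])   # (8, 32)
    # GenerateOptShardSliceShape can only succeed when the first dimension of
    # the slice is divisible by the repeat count.
    assert slice_shape[0] % repeat == 0
    opt_shard_slice_shape = (slice_shape[0] // repeat, slice_shape[1])   # (2, 32)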
void CoverSliceShape(const FuncGraphPtr &root) {
@@ -1460,14 +1498,18 @@ void CoverSliceShape(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(parameter->Shape());
auto iter = g_RefMap.find(parameter);
if (iter != g_RefMap.end()) {
SetParallelShape(parameter, g_RefMap[parameter]);
std::string group = SetParallelShape(parameter, g_RefMap[parameter]);
// find all forward nodes that use parameter in graphs and insert allgather if group is not empty
ApplyParallelOptOnParam(root, parameter, group);
continue;
}
std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
if (res.first == nullptr) {
MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
} else {
SetParallelShape(parameter, res);
std::string group = SetParallelShape(parameter, res);
// find all forward nodes that use parameter in graphs and insert allgather if group is not empty
ApplyParallelOptOnParam(root, parameter, group);
MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
}
}

@@ -109,7 +109,7 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node);
std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &parameter);
// Set distribute shape for parameters abstract
void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res);
std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res);
// change parameters'shape in resource
void CoverSliceShape(const FuncGraphPtr &root);

@@ -418,9 +418,9 @@ class Parameter(MetaTensor_):
if _is_role_worker():
data = self.init_mode.to_tensor(0, [1])
else:
data = self.init_mode.to_tensor(slice_index, layout[2])
data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
else:
data = self.init_mode.to_tensor(slice_index, layout[2])
data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
else:
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
if _is_role_worker():
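Here layout is the tuple stored in parameter_layout_dict; index 2 is the slice shape and index 5 is assumed to be the optimizer shard group name added by this change (the test at the end of this change set reads the same index). A hypothetical entry for illustration, with the remaining fields left as placeholders:

    # Hypothetical layout entry; only indices 2 (slice shape) and 5 (opt shard
    # group, possibly "") are relied on here.
    layout = (None, None, [16, 32], None, None, "4-11111_22222")
    shape, opt_shard_group = layout[2], layout[5]
    # data = self.init_mode.to_tensor(slice_index, shape, opt_shard_group)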

@@ -16,6 +16,7 @@
import numpy as np
from mindspore import log as logger
from mindspore.communication.management import get_rank, get_group_size
from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor as MetaTensor_
from .._checkparam import check_type, check_typename
@@ -409,7 +410,7 @@ class MetaTensor(MetaTensor_):
self.init = init
MetaTensor_.__init__(self, dtype, shape)
def to_tensor(self, slice_index=None, shape=None):
def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None):
"""
Get the tensor format data of this MetaTensor.
@@ -418,6 +419,8 @@ class MetaTensor(MetaTensor_):
It is used when initialize a slice of a parameter, it guarantees that devices
using the same slice can generate the same tensor.
shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
opt_shard_group (str): Optimizer shard group which is used in auto or semi auto parallel mode
to get one shard of a parameter's slice.
"""
if self.init is None:
raise TypeError("to_dense must be set MetaTensor.init, init can't be None")
@@ -453,7 +456,12 @@
with seed_context(self.init):
self.init(arr)
return Tensor(arr, dtype=self.dtype)
data = np.array(arr)
if opt_shard_group:
rank = get_rank(opt_shard_group)
size = get_group_size(opt_shard_group)
data = np.split(data, size)[rank]
return Tensor(data, dtype=self.dtype)
def _vm_compare(*args):
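The new branch in to_tensor splits the already-initialized slice along its first dimension across the optimizer shard group and keeps only the local rank's piece. A standalone numpy sketch of that path, with the rank and group size hard-coded instead of coming from get_rank/get_group_size:

    import numpy as np

    arr = np.random.randn(8, 4).astype(np.float32)   # fully initialized parameter slice
    size, rank = 4, 1                                # stand-ins for get_group_size()/get_rank()

    data = np.array(arr)
    data = np.split(data, size)[rank]                # keep only this rank's shard along dim 0
    assert data.shape == (8 // size, 4)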

@@ -16,7 +16,7 @@
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication import GlobalComm
_allgather_cell = None
@@ -26,10 +26,10 @@ class AllGatherCell(Cell):
Allgather cell, used in model parallel scenario.
To allgather the selected parameter slice from each device.
"""
def __init__(self):
def __init__(self, group):
super(AllGatherCell, self).__init__(auto_prefix=False)
self.allgather = AllGather()
self.allgather = AllGather(group)
def construct(self, x):
x = self.allgather(x)
@@ -58,13 +58,16 @@ class SaveOptShardCkptCell(Cell):
return x
def get_allgather_cell(group):
def get_allgather_cell(group, need_merge_twice=False):
"""Get AllGatherCell object."""
global _allgather_cell
if group:
if need_merge_twice:
_allgather_cell = SaveOptShardCkptCell(group)
else:
_allgather_cell = AllGatherCell()
if group:
_allgather_cell = AllGatherCell(group)
else:
_allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP)
return _allgather_cell
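need_merge_twice selects SaveOptShardCkptCell, which presumably gathers first over the optimizer shard group and then over the global group, for weights that are both model-parallel split and optimizer-sharded; AllGatherCell(group) alone covers the single-level cases. A numpy sketch of the two-level merge, with made-up sizes:

    import numpy as np

    full = np.arange(64.0).reshape(16, 4)
    mp_slices = np.split(full, 2)                      # model-parallel split across 2 devices
    opt_shards = [np.split(s, 4) for s in mp_slices]   # each slice optimizer-sharded across 4

    # merge twice: first within each optimizer shard group, then across the slices
    merged = np.concatenate([np.concatenate(g) for g in opt_shards])
    assert np.array_equal(merged, full)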

@@ -456,13 +456,18 @@ def _get_merged_param_data(net, param_name, param_data):
# while any dim is not equal to -1, means param is split and needs to be merged
# pipeline parallel need to be supported here later
for dim in tensor_map:
if dim != -1 or opt_shard_group:
allgather_net = get_allgather_cell(opt_shard_group)
if dim != -1:
if opt_shard_group:
allgather_net = get_allgather_cell(opt_shard_group, True)
else:
allgather_net = get_allgather_cell(opt_shard_group, False)
param_data = allgather_net(param_data)
if field_size:
return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
return _reshape_param_data(param_data, dev_mat, tensor_map)
if opt_shard_group:
allgather_net = get_allgather_cell(opt_shard_group, False)
param_data = allgather_net(param_data)
return param_data
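The trailing branch handles weights that are not split by the operator strategy (every tensor_map dim is -1) but are optimizer-sharded: a single AllGather over opt_shard_group restores the full weight. Sketched with numpy, again with a made-up group size of 4:

    import numpy as np

    full = np.ones((8, 4))                    # weight repeated on every device, not split
    shards = np.split(full, 4)                # but optimizer-sharded across a group of 4
    param_data = shards[2]                    # what one rank holds at save time
    assert param_data.shape == (2, 4)
    merged = np.concatenate(shards)           # effect of allgather_net(param_data)
    assert np.array_equal(merged, full)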

@@ -77,8 +77,9 @@ def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
train_network = TrainOneStepCell(net, optimizer)
train_network.set_auto_parallel()
train_network.set_train()
_executor.compile(train_network, inputs, label)
_executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
context.reset_auto_parallel_context()
return train_network
def test_auto_parallel_momentum_1():
@@ -93,7 +94,11 @@ def test_auto_parallel_momentum_2():
def test_auto_parallel_momentum_3():
# hybrid parallel case
# weight1 could not be shard and weight2 is repeated
auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
train_network = auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
param_dict = train_network.parameter_layout_dict
# validate opt_shard_group
assert not param_dict["weight1"][5]
assert param_dict["weight2"][5].startswith("4")
def test_auto_parallel_momentum_4():
