!7946 fix auto parallel optimizer weight shard

Merge pull request !7946 from gziyan/fix_auto_optim_shard
pull/7946/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e909b9077c

@@ -164,8 +164,6 @@ class OperatorInfo {
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
   void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
   int32_t stage_id() const { return stage_id_; }
-  void set_opt_shard_flag(bool flag) { opt_shard_flag_ = flag; }
-  bool opt_shard_flag() { return opt_shard_flag_; }
   Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
 
   // Key for user data.
@@ -269,7 +267,6 @@ class OperatorInfo {
  private:
   OperatorCostPtr operator_cost_;
   std::vector<TypePtr> outputs_type_;
-  bool opt_shard_flag_ = false;
 };
 
 Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);

@@ -808,11 +808,17 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
   }
 }
 
+// Only used for InsertMirrorOps
 std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
     return std::make_pair(nullptr, false);
   } else if (node->isa<Parameter>()) {
+    auto param_ptr = node->user_data<parallel::TensorLayout>();
+    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+      return std::make_pair(nullptr, false);
+    } else {
       return std::make_pair(node, false);
+    }
   } else if (node->isa<ValueNode>()) {
     if (IsValueNode<RefKey>(node)) {
       std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
@@ -820,8 +826,13 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
         MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
                           << param_v.size();
       }
+      auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
+      if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+        return std::make_pair(nullptr, true);
+      } else {
         return std::make_pair(node, true);
+      }
     }
     return std::make_pair(nullptr, false);
   } else {
     CNodePtr cnode = node->cast<CNodePtr>();
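The FindParameter change above makes InsertMirrorOps skip any parameter whose layout already carries a non-empty opt_shard_group: the lookup reports "no parameter found", so no mirror operator is attached to an optimizer-sharded weight. A minimal hedged sketch of that predicate, using a plain dict as a stand-in for the TensorLayout user data (not MindSpore API; the group name is made up):

```python
# Stand-in for the C++ check: mirror insertion is skipped when the parameter's
# layout already records an optimizer-shard group.
def skip_mirror_for(layout):
    return layout is not None and bool(layout.get("opt_shard_group"))

assert skip_mirror_for({"opt_shard_group": "4-11-22"})   # opt-sharded: no mirror op
assert not skip_mirror_for({"opt_shard_group": ""})      # ordinary parameter
assert not skip_mirror_for(None)                         # no layout attached yet
```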
@@ -1002,7 +1013,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
   MirrorOps mirror_ops = distribute_operator->mirror_ops();
   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
   // insert mirror op
-  if (!mirror_ops.empty() && !distribute_operator->opt_shard_flag()) {
+  if (!mirror_ops.empty()) {
     MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
     InsertMirrorOps(mirror_ops, node);
   }
@@ -1374,39 +1385,51 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
   return std::make_pair(nullptr, 0);
 }
 
-void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
-                             const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index) {
-  MS_EXCEPTION_IF_NULL(distribute_operator);
-  MS_EXCEPTION_IF_NULL(cnode);
+void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, const AnfNodePtr &parameter) {
+  Operator op = CreateAllGatherOp(group);
+  MS_EXCEPTION_IF_NULL(res.first);
   MS_EXCEPTION_IF_NULL(parameter);
-  std::vector<Group> dev_group;
-  // create communication group for allgather operator
-  if (distribute_operator->CreateGroupByTensorMap(tensor_layout->origin_tensor_map().array(), &dev_group) ==
-        Status::SUCCESS &&
-      !dev_group.empty()) {
-    // set optimizer shard split flag to avoid inserting mirror_ops
-    distribute_operator->set_opt_shard_flag(true);
-    // insert allgather operator between shard parameter and cnode
-    Operator op = CreateAllGatherOp(dev_group[0].name());
+  auto cnode = res.first->cast<CNodePtr>();
   auto graph = cnode->func_graph();
   MS_EXCEPTION_IF_NULL(graph);
-  InsertNode(op, cnode, index, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
-  // set communication group in tensor layout for checkpoint saving
-  tensor_layout->set_opt_shard_group(dev_group[0].name());
+  InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
   // add fusion flag
-  auto allgather = cnode->input(index)->cast<CNodePtr>();
+  auto allgather = cnode->input(res.second)->cast<CNodePtr>();
   auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
   auto attrs = prim->attrs();
   // enable fusion flag later when it's supported in backend
   attrs["fusion"] = MakeValue(0);
   prim->SetAttrs(attrs);
-    MS_LOG(INFO) << "Parallel optimizer is applied on " << parameter->ToString();
-  } else {
-    MS_LOG(ERROR) << "Parallel optimizer applied on " << parameter->ToString() << "failed!";
+}
+
+void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
+                             const std::string &opt_shard_group) {
+  if (opt_shard_group.empty()) {
+    return;
+  }
+  FuncGraphManagerPtr manager = root->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto param_sub_set = manager->node_users()[parameter];
+  for (auto &param_pair : param_sub_set) {
+    auto cnode = param_pair.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (cnode->in_forward_flag()) {
+      OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
+      if (distribute_operator == nullptr) {
+        MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
+      } else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
+        MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
+                          << distribute_operator->inputs_tensor_info().size();
+      }
+      // insert allgather operator between shard parameter and cnode
+      InsertAllGatherOp(opt_shard_group, param_pair, parameter);
+      MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
+    }
   }
 }
 
-void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
+// When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
+std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
   MS_EXCEPTION_IF_NULL(parameter);
   AbstractBasePtr abstract = parameter->abstract();
   MS_EXCEPTION_IF_NULL(abstract);
@@ -1417,26 +1440,40 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
   if (distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
   }
   if (IntToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
     MS_LOG(EXCEPTION) << "The index is out of range, index is " << res.second - 1 << ", vector size is "
                       << distribute_operator->inputs_tensor_info().size();
   }
   TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
-  Shape slice_shape = tensor_layout.slice_shape().array();
+  std::string opt_shard_group;
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
+  Shape slice_shape = tensor_layout.slice_shape().array();
   if (enable_parallel_optimizer) {
     if (!ParameterRequireGrad(parameter)) {
       // only trainable parameters need parallel optimizer
-      MS_LOG(INFO) << "Parallel optimizer is no need for " << parameter->ToString();
+      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
     } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
       // get a totally shard tensor slice shape if the weight is repeated on devices
       // and the shape of the first dimension could be divided
       // apply parallel optimizer on parameters
-      ApplyParallelOptOnParam(&tensor_layout, distribute_operator, cnode, parameter, IntToSize(res.second));
+      // create communication group for allgather operator
       slice_shape = tensor_layout.opt_shard_slice_shape();
+      std::vector<Group> dev_group;
+      if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
+            Status::SUCCESS &&
+          !dev_group.empty()) {
+        opt_shard_group = dev_group[0].name();
+        // set communication group in tensor layout for checkpoint saving
+        tensor_layout.set_opt_shard_group(opt_shard_group);
+        MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
+                     << " success.";
+      } else {
+        MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
+      }
+    } else {
+      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
     }
   }
   MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
@@ -1451,6 +1488,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
   parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
+  return opt_shard_group;
 }
 
 void CoverSliceShape(const FuncGraphPtr &root) {
@@ -1460,14 +1498,18 @@ void CoverSliceShape(const FuncGraphPtr &root) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
     auto iter = g_RefMap.find(parameter);
     if (iter != g_RefMap.end()) {
-      SetParallelShape(parameter, g_RefMap[parameter]);
+      std::string group = SetParallelShape(parameter, g_RefMap[parameter]);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
+      ApplyParallelOptOnParam(root, parameter, group);
      continue;
     }
     std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
-      SetParallelShape(parameter, res);
+      std::string group = SetParallelShape(parameter, res);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
+      ApplyParallelOptOnParam(root, parameter, group);
       MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }
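Taken together, the step_parallel.cc changes split the old ApplyParallelOptOnParam in two: SetParallelShape still computes the shard slice shape, but now also creates the communication group and returns its name (an empty string means the optimizer is not sharded for that weight), and CoverSliceShape passes that name to the new ApplyParallelOptOnParam, which walks every forward user of the parameter and lets InsertAllGatherOp splice an AllGather onto each edge. A rough, hedged Python sketch of that control flow with dictionary stand-ins for the graph objects (none of this is MindSpore API, and the group name is made up):

```python
# Hypothetical stand-ins; only the control flow of the diff above is mirrored here.
def set_parallel_shape(parameter, layout, group_created):
    """Return the opt-shard group name, or "" when the optimizer is not sharded."""
    if not parameter["requires_grad"]:
        return ""                      # only trainable parameters are considered
    if not layout["first_dim_divisible"]:
        return ""                      # GenerateOptShardSliceShape failed
    return layout["group_name"] if group_created else ""

def apply_parallel_opt_on_param(node_users, parameter, opt_shard_group):
    """Insert an AllGather between the sharded parameter and each forward user."""
    if not opt_shard_group:
        return
    for cnode, input_index in node_users[parameter["name"]]:
        if cnode["in_forward"]:
            cnode["inputs"][input_index] = ("AllGather", opt_shard_group, parameter["name"])

# usage, mirroring CoverSliceShape
w = {"name": "w", "requires_grad": True}
layout = {"first_dim_divisible": True, "group_name": "4-11-22"}
users = {"w": [({"in_forward": True, "inputs": {1: "w"}}, 1)]}
group = set_parallel_shape(w, layout, group_created=True)
apply_parallel_opt_on_param(users, w, group)
assert users["w"][0][0]["inputs"][1][0] == "AllGather"
```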

@@ -109,7 +109,7 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node);
 std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &parameter);
 
 // Set distribute shape for parameters abstract
-void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res);
+std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res);
 
 // change parameters'shape in resource
 void CoverSliceShape(const FuncGraphPtr &root);

@@ -418,9 +418,9 @@ class Parameter(MetaTensor_):
                 if _is_role_worker():
                     data = self.init_mode.to_tensor(0, [1])
                 else:
-                    data = self.init_mode.to_tensor(slice_index, layout[2])
+                    data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
             else:
-                data = self.init_mode.to_tensor(slice_index, layout[2])
+                data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
         else:
             if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
                 if _is_role_worker():
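For reference, the layout tuple indexed above is the same one later exposed through parameter_layout_dict and checked by the test at the bottom of this diff: index 2 is the slice shape handed to to_tensor as shape, and index 5 is the optimizer-shard group name handed on as opt_shard_group. A tiny placeholder illustration (the other slots are dummies here, not their real contents):

```python
# Placeholder tuple: only indices 2 and 5 are meaningful in this illustration.
layout = (None, None, [8, 4], None, None, "4-11-22")
slice_shape, opt_shard_group = layout[2], layout[5]
assert slice_shape == [8, 4]
assert opt_shard_group.startswith("4")   # same kind of check as the test below
```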

@@ -16,6 +16,7 @@
 import numpy as np
 
 from mindspore import log as logger
+from mindspore.communication.management import get_rank, get_group_size
 from .._c_expression import Tensor as Tensor_
 from .._c_expression import MetaTensor as MetaTensor_
 from .._checkparam import check_type, check_typename
@@ -409,7 +410,7 @@ class MetaTensor(MetaTensor_):
         self.init = init
         MetaTensor_.__init__(self, dtype, shape)
 
-    def to_tensor(self, slice_index=None, shape=None):
+    def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None):
         """
         Get the tensor format data of this MetaTensor.
 
@@ -418,6 +419,8 @@ class MetaTensor(MetaTensor_):
                 It is used when initialize a slice of a parameter, it guarantees that devices
                 using the same slice can generate the same tensor.
             shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
+            opt_shard_group(str): Optimizer shard group which is used in auto or semi auto parallel mode
+                to get one shard of a parameter's slice.
         """
         if self.init is None:
             raise TypeError("to_dense must be set MetaTensor.init, init can't be None")
@@ -453,7 +456,12 @@ class MetaTensor(MetaTensor_):
         with seed_context(self.init):
             self.init(arr)
-        return Tensor(arr, dtype=self.dtype)
+        data = np.array(arr)
+        if opt_shard_group:
+            rank = get_rank(opt_shard_group)
+            size = get_group_size(opt_shard_group)
+            data = np.split(data, size)[rank]
+        return Tensor(data, dtype=self.dtype)
 
 
 def _vm_compare(*args):
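The to_tensor change is where a weight is actually reduced to its optimizer shard: after the initializer fills the full slice, np.split cuts it evenly along the first axis across the devices of opt_shard_group and each rank keeps only its own piece. A self-contained numpy illustration with hard-coded rank and group size (in the real code they come from get_rank and get_group_size):

```python
import numpy as np

size = 4    # stand-in for get_group_size(opt_shard_group)
rank = 1    # stand-in for get_rank(opt_shard_group)

data = np.arange(32, dtype=np.float32).reshape(8, 4)   # full slice after initialization
shard = np.split(data, size)[rank]                     # this rank keeps rows 2..3 only

assert shard.shape == (2, 4)
assert (shard == data[2:4]).all()
```

Note that np.split raises if the first dimension is not divisible by the group size, which matches the divisibility condition checked on the C++ side (GenerateOptShardSliceShape) before the group is created.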

@@ -16,7 +16,7 @@
 from mindspore.nn.cell import Cell
 from mindspore.ops.operations.comm_ops import AllGather
+from mindspore.communication import GlobalComm
 
 _allgather_cell = None
@@ -26,10 +26,10 @@ class AllGatherCell(Cell):
     Allgather cell, used in model parallel scenario.
     To allgather the selected parameter slice from each device.
     """
-    def __init__(self):
+    def __init__(self, group):
         super(AllGatherCell, self).__init__(auto_prefix=False)
-        self.allgather = AllGather()
+        self.allgather = AllGather(group)
 
     def construct(self, x):
         x = self.allgather(x)
@@ -58,13 +58,16 @@ class SaveOptShardCkptCell(Cell):
         return x
 
 
-def get_allgather_cell(group):
+def get_allgather_cell(group, need_merge_twice=False):
     """Get AllGatherCell object."""
     global _allgather_cell
-    if group:
+    if need_merge_twice:
         _allgather_cell = SaveOptShardCkptCell(group)
     else:
-        _allgather_cell = AllGatherCell()
+        if group:
+            _allgather_cell = AllGatherCell(group)
+        else:
+            _allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP)
 
     return _allgather_cell
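On the checkpoint side, get_allgather_cell now distinguishes two merge strategies: need_merge_twice=True selects SaveOptShardCkptCell(group) for weights that are both model-parallel split and optimizer-sharded, while everything else uses AllGatherCell over either the given group or the whole world group. A hedged sketch of just that selection (tuples and the world-group string stand in for the real cells and for GlobalComm.WORLD_COMM_GROUP):

```python
# Selection logic only; the returned tuples stand in for the real cells.
WORLD_COMM_GROUP = "world_group"   # placeholder for GlobalComm.WORLD_COMM_GROUP

def pick_allgather_cell(group, need_merge_twice=False):
    if need_merge_twice:
        return ("SaveOptShardCkptCell", group)        # two-step merge for split + opt-sharded weights
    if group:
        return ("AllGatherCell", group)               # gather only inside the opt-shard group
    return ("AllGatherCell", WORLD_COMM_GROUP)        # plain model-parallel merge

assert pick_allgather_cell("4-11-22", True) == ("SaveOptShardCkptCell", "4-11-22")
assert pick_allgather_cell("4-11-22") == ("AllGatherCell", "4-11-22")
assert pick_allgather_cell("") == ("AllGatherCell", WORLD_COMM_GROUP)
```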

@@ -456,13 +456,18 @@ def _get_merged_param_data(net, param_name, param_data):
     # while any dim is not equal to -1, means param is split and needs to be merged
     # pipeline parallel need to be supported here later
     for dim in tensor_map:
-        if dim != -1 or opt_shard_group:
-            allgather_net = get_allgather_cell(opt_shard_group)
+        if dim != -1:
+            if opt_shard_group:
+                allgather_net = get_allgather_cell(opt_shard_group, True)
+            else:
+                allgather_net = get_allgather_cell(opt_shard_group, False)
             param_data = allgather_net(param_data)
             if field_size:
                 return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
             return _reshape_param_data(param_data, dev_mat, tensor_map)
+    if opt_shard_group:
+        allgather_net = get_allgather_cell(opt_shard_group, False)
+        param_data = allgather_net(param_data)
     return param_data
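The merge path in _get_merged_param_data therefore has two layers: a weight whose tensor map still contains a split dimension goes through the need_merge_twice=True path (a two-step gather), while a weight that is only optimizer-sharded gets a single gather inside its group. A small numpy sketch of the single-gather case, the inverse of the to_tensor split shown earlier:

```python
import numpy as np

full = np.arange(32, dtype=np.float32).reshape(8, 4)   # original parameter
size = 4                                               # devices in the opt-shard group
shards = np.split(full, size)                          # per-rank pieces after sharding

# One AllGather inside the opt-shard group concatenates the pieces back in rank order.
merged = np.concatenate(shards, axis=0)
assert (merged == full).all()
```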

@@ -77,8 +77,9 @@ def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
     train_network = TrainOneStepCell(net, optimizer)
     train_network.set_auto_parallel()
     train_network.set_train()
-    _executor.compile(train_network, inputs, label)
+    _executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
     context.reset_auto_parallel_context()
+    return train_network
 
 
 def test_auto_parallel_momentum_1():
@@ -93,7 +94,11 @@ def test_auto_parallel_momentum_2():
 def test_auto_parallel_momentum_3():
     # hybrid parallel case
     # weight1 could not be shard and weight2 is repeated
-    auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    param_dict = train_network.parameter_layout_dict
+    # validate opt_shard_group
+    assert not param_dict["weight1"][5]
+    assert param_dict["weight2"][5].startswith("4")
 
 
 def test_auto_parallel_momentum_4():
