diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
index 0eff3e15f9..d3fd2a645e 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -164,8 +164,6 @@ class OperatorInfo {
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
   void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
   int32_t stage_id() const { return stage_id_; }
-  void set_opt_shard_flag(bool flag) { opt_shard_flag_ = flag; }
-  bool opt_shard_flag() { return opt_shard_flag_; }
   Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);

   // Key for user data.
@@ -269,7 +267,6 @@ class OperatorInfo {
 private:
   OperatorCostPtr operator_cost_;
   std::vector<TypePtr> outputs_type_;
-  bool opt_shard_flag_ = false;
 };

 Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 5d1432a305..33f403616c 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -808,11 +808,17 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
   }
 }

+// Only used for InsertMirrorOps
 std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
     return std::make_pair(nullptr, false);
   } else if (node->isa<Parameter>()) {
-    return std::make_pair(node, false);
+    auto param_ptr = node->user_data<TensorLayout>();
+    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+      return std::make_pair(nullptr, false);
+    } else {
+      return std::make_pair(node, false);
+    }
   } else if (node->isa<ValueNode>()) {
     if (IsValueNode<RefKey>(node)) {
       std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
@@ -820,7 +826,12 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
         MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
                           << param_v.size();
       }
-      return std::make_pair(node, true);
+      auto param_ptr = param_v[0]->user_data<TensorLayout>();
+      if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+        return std::make_pair(nullptr, true);
+      } else {
+        return std::make_pair(node, true);
+      }
     }
     return std::make_pair(nullptr, false);
   } else {
@@ -1002,7 +1013,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
   MirrorOps mirror_ops = distribute_operator->mirror_ops();
   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
   // insert mirror op
-  if (!mirror_ops.empty() && !distribute_operator->opt_shard_flag()) {
+  if (!mirror_ops.empty()) {
     MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
     InsertMirrorOps(mirror_ops, node);
   }
@@ -1374,39 +1385,51 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
   return std::make_pair(nullptr, 0);
 }

-void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
-                             const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index) {
-  MS_EXCEPTION_IF_NULL(distribute_operator);
-  MS_EXCEPTION_IF_NULL(cnode);
+void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, const AnfNodePtr &parameter) {
+  Operator op = CreateAllGatherOp(group);
+  MS_EXCEPTION_IF_NULL(res.first);
   MS_EXCEPTION_IF_NULL(parameter);
-  std::vector<Group> dev_group;
-  // create communication group for allgather operator
-  if (distribute_operator->CreateGroupByTensorMap(tensor_layout->origin_tensor_map().array(), &dev_group) ==
-        Status::SUCCESS &&
-      !dev_group.empty()) {
-    // set optimizer shard split flag to avoid inserting mirror_ops
-    distribute_operator->set_opt_shard_flag(true);
-    // insert allgather operator between shard parameter and cnode
-    Operator op = CreateAllGatherOp(dev_group[0].name());
-    auto graph = cnode->func_graph();
-    MS_EXCEPTION_IF_NULL(graph);
-    InsertNode(op, cnode, index, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
-    // set communication group in tensor layout for checkpoint saving
-    tensor_layout->set_opt_shard_group(dev_group[0].name());
-    // add fusion flag
-    auto allgather = cnode->input(index)->cast<CNodePtr>();
-    auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
-    auto attrs = prim->attrs();
-    // enable fusion flag later when it's supported in backend
-    attrs["fusion"] = MakeValue(0);
-    prim->SetAttrs(attrs);
-    MS_LOG(INFO) << "Parallel optimizer is applied on " << parameter->ToString();
-  } else {
-    MS_LOG(ERROR) << "Parallel optimizer applied on " << parameter->ToString() << "failed!";
+  auto cnode = res.first->cast<CNodePtr>();
+  auto graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(graph);
+  InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
+  // add fusion flag
+  auto allgather = cnode->input(res.second)->cast<CNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
+  auto attrs = prim->attrs();
+  // enable fusion flag later when it's supported in backend
+  attrs["fusion"] = MakeValue(0);
+  prim->SetAttrs(attrs);
+}
+
+void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
+                             const std::string &opt_shard_group) {
+  if (opt_shard_group.empty()) {
+    return;
+  }
+  FuncGraphManagerPtr manager = root->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto param_sub_set = manager->node_users()[parameter];
+  for (auto &param_pair : param_sub_set) {
+    auto cnode = param_pair.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (cnode->in_forward_flag()) {
+      OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
+      if (distribute_operator == nullptr) {
+        MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
+      } else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
+        MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
+                          << distribute_operator->inputs_tensor_info().size();
+      }
+      // insert allgather operator between shard parameter and cnode
+      InsertAllGatherOp(opt_shard_group, param_pair, parameter);
+      MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
+    }
   }
 }

-void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
+// When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
+std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
   MS_EXCEPTION_IF_NULL(parameter);
   AbstractBasePtr abstract = parameter->abstract();
   MS_EXCEPTION_IF_NULL(abstract);
@@ -1417,26 +1440,40 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, in
   if (distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
   }
-
   if (IntToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
     MS_LOG(EXCEPTION) << "The index is out of range, index is " << res.second - 1 << ", vector size is "
                       << distribute_operator->inputs_tensor_info().size();
   }
   TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
+  Shape slice_shape = tensor_layout.slice_shape().array();
+  std::string opt_shard_group;
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
-  Shape slice_shape = tensor_layout.slice_shape().array();
   if (enable_parallel_optimizer) {
     if (!ParameterRequireGrad(parameter)) {
       // only trainable parameters need parallel optimizer
-      MS_LOG(INFO) << "Parallel optimizer is no need for " << parameter->ToString();
+      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
     } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
       // get a totally shard tensor slice shape if the weight is repeated on devices
       // and the shape of the first dimension could be divided
       // apply parallel optimizer on parameters
-      ApplyParallelOptOnParam(&tensor_layout, distribute_operator, cnode, parameter, IntToSize(res.second));
+      // create communication group for allgather operator
       slice_shape = tensor_layout.opt_shard_slice_shape();
+      std::vector<Group> dev_group;
+      if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
+            Status::SUCCESS &&
+          !dev_group.empty()) {
+        opt_shard_group = dev_group[0].name();
+        // set communication group in tensor layout for checkpoint saving
+        tensor_layout.set_opt_shard_group(opt_shard_group);
+        MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
+                     << " success.";
+      } else {
+        MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
+      }
+    } else {
+      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
     }
   }
   MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
@@ -1451,6 +1488,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, in
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
   parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
+  return opt_shard_group;
 }

 void CoverSliceShape(const FuncGraphPtr &root) {
@@ -1460,14 +1498,18 @@ void CoverSliceShape(const FuncGraphPtr &root) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
     auto iter = g_RefMap.find(parameter);
     if (iter != g_RefMap.end()) {
-      SetParallelShape(parameter, g_RefMap[parameter]);
+      std::string group = SetParallelShape(parameter, g_RefMap[parameter]);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
+      ApplyParallelOptOnParam(root, parameter, group);
       continue;
     }
     std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
-      SetParallelShape(parameter, res);
+      std::string group = SetParallelShape(parameter, res);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
+      ApplyParallelOptOnParam(root, parameter, group);
       MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index 84a9aeb5fb..00f60b39b6 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -109,7 +109,7 @@ std::vector<Shape> ExtractShape(const CNodePtr &node);
 std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &parameter);

 // Set distribute shape for parameters abstract
-void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res);
+std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res);

 // change parameters'shape in resource
 void CoverSliceShape(const FuncGraphPtr &root);
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index 9497cb574b..76caacba7b 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -418,9 +418,9 @@ class Parameter(MetaTensor_):
                 if _is_role_worker():
                     data = self.init_mode.to_tensor(0, [1])
                 else:
-                    data = self.init_mode.to_tensor(slice_index, layout[2])
+                    data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
             else:
-                data = self.init_mode.to_tensor(slice_index, layout[2])
+                data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
         else:
             if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
                 if _is_role_worker():
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 21e03434a3..7b5ae75f24 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -16,6 +16,7 @@
 import numpy as np

 from mindspore import log as logger
+from mindspore.communication.management import get_rank, get_group_size
 from .._c_expression import Tensor as Tensor_
 from .._c_expression import MetaTensor as MetaTensor_
 from .._checkparam import check_type, check_typename
@@ -409,7 +410,7 @@ class MetaTensor(MetaTensor_):
         self.init = init
         MetaTensor_.__init__(self, dtype, shape)

-    def to_tensor(self, slice_index=None, shape=None):
+    def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None):
         """
         Get the tensor format data of this MetaTensor.

@@ -418,6 +419,8 @@ class MetaTensor(MetaTensor_):
                 It is used when initialize a slice of a parameter, it guarantees that devices
                 using the same slice can generate the same tensor.
             shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
+            opt_shard_group(str): Optimizer shard group which is used in auto or semi auto parallel mode
+                to get one shard of a parameter's slice.
         """
         if self.init is None:
             raise TypeError("to_dense must be set MetaTensor.init, init can't be None")
@@ -453,7 +456,12 @@ class MetaTensor(MetaTensor_):

         with seed_context(self.init):
             self.init(arr)
-        return Tensor(arr, dtype=self.dtype)
+        data = np.array(arr)
+        if opt_shard_group:
+            rank = get_rank(opt_shard_group)
+            size = get_group_size(opt_shard_group)
+            data = np.split(data, size)[rank]
+        return Tensor(data, dtype=self.dtype)


 def _vm_compare(*args):
diff --git a/mindspore/parallel/_cell_wrapper.py b/mindspore/parallel/_cell_wrapper.py
index 9845301da8..88c67cbc8b 100644
--- a/mindspore/parallel/_cell_wrapper.py
+++ b/mindspore/parallel/_cell_wrapper.py
@@ -16,7 +16,7 @@

 from mindspore.nn.cell import Cell
 from mindspore.ops.operations.comm_ops import AllGather
-
+from mindspore.communication import GlobalComm

 _allgather_cell = None

@@ -26,10 +26,10 @@ class AllGatherCell(Cell):
     Allgather cell, used in model parallel scenario.
     To allgather the selected parameter slice from each device.
     """
-    def __init__(self):
+    def __init__(self, group):
         super(AllGatherCell, self).__init__(auto_prefix=False)

-        self.allgather = AllGather()
+        self.allgather = AllGather(group)

     def construct(self, x):
         x = self.allgather(x)
@@ -58,13 +58,16 @@ class SaveOptShardCkptCell(Cell):
         return x


-def get_allgather_cell(group):
+def get_allgather_cell(group, need_merge_twice=False):
     """Get AllGatherCell object."""
     global _allgather_cell
-    if group:
+    if need_merge_twice:
         _allgather_cell = SaveOptShardCkptCell(group)
     else:
-        _allgather_cell = AllGatherCell()
+        if group:
+            _allgather_cell = AllGatherCell(group)
+        else:
+            _allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP)

     return _allgather_cell

diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index 3cbbaa902f..78dd215fce 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -456,13 +456,18 @@ def _get_merged_param_data(net, param_name, param_data):
     # while any dim is not equal to -1, means param is split and needs to be merged
     # pipeline parallel need to be supported here later
     for dim in tensor_map:
-        if dim != -1 or opt_shard_group:
-            allgather_net = get_allgather_cell(opt_shard_group)
+        if dim != -1:
+            if opt_shard_group:
+                allgather_net = get_allgather_cell(opt_shard_group, True)
+            else:
+                allgather_net = get_allgather_cell(opt_shard_group, False)
             param_data = allgather_net(param_data)
             if field_size:
                 return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
             return _reshape_param_data(param_data, dev_mat, tensor_map)
-
+    if opt_shard_group:
+        allgather_net = get_allgather_cell(opt_shard_group, False)
+        param_data = allgather_net(param_data)
     return param_data


diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py
index 21f39346c3..c850643745 100644
--- a/tests/ut/python/parallel/test_parallel_optimizer.py
+++ b/tests/ut/python/parallel/test_parallel_optimizer.py
@@ -77,8 +77,9 @@ def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
     train_network = TrainOneStepCell(net, optimizer)
     train_network.set_auto_parallel()
     train_network.set_train()
-    _executor.compile(train_network, inputs, label)
+    _executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
     context.reset_auto_parallel_context()
+    return train_network


 def test_auto_parallel_momentum_1():
@@ -93,7 +94,11 @@ def test_auto_parallel_momentum_2():
 def test_auto_parallel_momentum_3():
     # hybrid parallel case
     # weight1 could not be shard and weight2 is repeated
-    auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    param_dict = train_network.parameter_layout_dict
+    # validate opt_shard_group
+    assert not param_dict["weight1"][5]
+    assert param_dict["weight2"][5].startswith("4")


 def test_auto_parallel_momentum_4():
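Note (not part of the patch): a minimal NumPy-only sketch of the extra optimizer-shard split that to_tensor applies when opt_shard_group is given. The rank and size values are hypothetical stand-ins for get_rank(opt_shard_group) and get_group_size(opt_shard_group), so no communication backend is needed to follow the logic.

import numpy as np

# Hypothetical stand-ins for get_rank(opt_shard_group) / get_group_size(opt_shard_group).
rank, size = 1, 4

# arr plays the role of the parameter slice produced by the existing initialization path.
arr = np.arange(32, dtype=np.float32).reshape(8, 4)

# Keep only this rank's shard of dimension 0, as the new opt_shard_group branch does,
# so each device in the group initializes 1/size of the slice; the inserted AllGather
# restores the full slice inside the forward graph.
shard = np.split(arr, size)[rank]
assert shard.shape == (2, 4)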