diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index f823a302c2..0c10896e3e 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -942,6 +942,29 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
   return (type_id != kNumberTypeFloat32);
 }
 
+static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
+  MS_EXCEPTION_IF_NULL(comm_node);
+  MS_EXCEPTION_IF_NULL(param_node);
+  if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
+    MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
+    return;
+  }
+  auto param = param_node->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(param);
+  auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
+  MS_EXCEPTION_IF_NULL(prim);
+  auto attrs = prim->attrs();
+  auto param_info = param->param_info();
+  if (!param_info) {
+    MS_LOG(WARNING) << param->ToString() << " does not have parameter info.";
+    return;
+  }
+  int32_t fusion_type = param_info->comm_fusion();
+  attrs[FUSION] = MakeValue(fusion_type);
+  prim->SetAttrs(attrs);
+  MS_LOG(INFO) << "Set comm fusion: " << param->param_info()->name() << "'s fusion type is " << fusion_type;
+}
+
 void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   size_t node_size = node->inputs().size();
@@ -1006,11 +1029,19 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
         MS_EXCEPTION_IF_NULL(cnode);
         AnfNodePtr pre_node = cnode->input(1);
         InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name);
+        auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
+        // add fusion flag
+        // pipeline mirror would not be set, which should be supported later
+        AddCommOpFusionType(comm_op, param_node_pair.first);
       }
     } else {
       for (auto &op : backward_op) {
        AnfNodePtr pre_node = node->input(index);
        InsertNode(op, node, index, pre_node, func_graph, instance_name);
+       auto comm_op = node->input(index)->cast<CNodePtr>();
+       // add fusion flag
+       // pipeline mirror would not be set, which should be supported later
+       AddCommOpFusionType(comm_op, param_node_pair.first);
       }
     }
   }
@@ -1342,7 +1373,8 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
   return std::make_pair(nullptr, 0);
 }
 
-void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int64_t> &res, const AnfNodePtr &parameter) {
+static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int64_t> &res,
+                              const AnfNodePtr &parameter) {
   Operator op = CreateAllGatherOp(group);
   MS_EXCEPTION_IF_NULL(res.first);
   MS_EXCEPTION_IF_NULL(parameter);
@@ -1360,11 +1392,7 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int
-  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
-  auto attrs = prim->attrs();
-  // enable fusion flag later when it's supported in backend
-  attrs["fusion"] = MakeValue(1);
-  prim->SetAttrs(attrs);
+  AddCommOpFusionType(allgather, parameter);
 }
 
 static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
@@ -1419,6 +1447,9 @@ std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNod
   if (!ParameterRequireGrad(parameter)) {
     // only trainable parameters need parallel optimizer
     MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
+  } else if (parameter->cast<ParameterPtr>()->param_info() &&
+             !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
+    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
   } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
     // get a totally shard tensor slice shape if the weight is repeated on devices
     // and the shape of the first dimension could be divided
diff --git a/mindspore/ccsrc/pybind_api/ir/param_info_py.cc b/mindspore/ccsrc/pybind_api/ir/param_info_py.cc
index 74f2730cf3..ea3c9df21f 100644
--- a/mindspore/ccsrc/pybind_api/ir/param_info_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/param_info_py.cc
@@ -29,6 +29,9 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
                            .def_property("init_in_server", &ParamInfo::init_in_server, &ParamInfo::set_init_in_server)
                            .def_property("layerwise_parallel", &ParamInfo::layerwise_parallel,
                                          &ParamInfo::set_layerwise_parallel)
+                           .def_property("parallel_optimizer", &ParamInfo::parallel_optimizer,
+                                         &ParamInfo::set_parallel_optimizer)
+                           .def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
                            .def(py::pickle(
                              [](const ParamInfo &p) {  // __getstate__
                                return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index f27b1b7825..a53bbf17ad 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -75,8 +75,11 @@ class Parameter(MetaTensor_):
         default_input (Union[Tensor, MetaTensor, Number]): Parameter data, to be set initialized.
         name (str): Name of the child parameter. Default: None.
         requires_grad (bool): True if the parameter requires gradient. Default: True.
-        layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
+        layerwise_parallel (bool): When layerwise_parallel is true in data parallel mode,
             broadcast and gradients communication would not be applied to parameters. Default: False.
+        parallel_optimizer (bool): It is used to filter the weight shard operation in semi auto or auto parallel
+            mode. It works only when the parallel optimizer is enabled in
+            `mindspore.context.set_auto_parallel_context()`. Default: True.
 
     Example:
         >>> from mindspore import Parameter, Tensor
@@ -132,19 +135,21 @@ class Parameter(MetaTensor_):
         return (
             Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
 
-    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False):
+    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
         self._param_info = ParamInfo()
         self.init_in_server = False
         self.cache_enable = False
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
+        self.parallel_optimizer = parallel_optimizer
         # this flag for tensor copy data.
         self.init_flag = False
         # this flag is for ge variable copy data.
         self._is_init = False
         self._inited_param = None
         self._sliced = False
+        self.comm_fusion = 1
         self.is_param_ps = False
         self._cast_type = None
         self._unique = False
@@ -210,7 +215,6 @@ class Parameter(MetaTensor_):
            raise RuntimeError("Must complete following two steps before calling set_param_ps: \
                               1. set_ps_context(enable_ps=True) \
                               2. export MS_ROLE environment variable.")
-
        if init_in_server and (not self.name.endswith("embedding_table")):
            raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
                               "sparse operator support initialization in server.".format(self.name))
@@ -218,7 +222,6 @@ class Parameter(MetaTensor_):
         self.init_in_server = init_in_server
         self._param_info.init_in_server = init_in_server
 
-
     @property
     def inited_param(self):
         """
@@ -273,6 +276,16 @@ class Parameter(MetaTensor_):
     def sliced(self, sliced_):
         self._sliced = sliced_
 
+    @property
+    def comm_fusion(self):
+        """Get the fusion type for communication operators corresponding to this parameter."""
+        return self._param_info.comm_fusion
+
+    @comm_fusion.setter
+    def comm_fusion(self, comm_fusion_):
+        """Set the fusion type for communication operators corresponding to this parameter."""
+        self._param_info.comm_fusion = comm_fusion_
+
     @property
     def unique(self):
         """whether the parameter is already unique or not."""
@@ -338,6 +351,17 @@ class Parameter(MetaTensor_):
             raise TypeError("`layerwise_parallel` parameter must be bool type")
         self._param_info.layerwise_parallel = value
 
+    @property
+    def parallel_optimizer(self):
+        """Return whether the parameter requires weight shard for parallel optimizer."""
+        return self._param_info.parallel_optimizer
+
+    @parallel_optimizer.setter
+    def parallel_optimizer(self, value=True):
+        if not isinstance(value, bool):
+            raise TypeError("`parallel_optimizer` parameter must be bool type")
+        self._param_info.parallel_optimizer = value
+
     @property
     def requires_grad(self):
         """Return whether the parameter requires gradient."""
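The two Parameter properties added above are thin wrappers around the new ParamInfo fields. A minimal usage sketch (not part of the patch; the parameter name is illustrative) of how they are expected to be set from Python once this change is applied:

    import numpy as np
    from mindspore import Parameter, Tensor

    # comm_fusion tags the communication operators generated for this parameter
    # (mirror / AllGather) with fusion bucket 2; parallel_optimizer=False keeps
    # the weight out of the optimizer weight shard.
    weight = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)),
                       name="fc_weight", parallel_optimizer=False)
    weight.comm_fusion = 2
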
diff --git a/mindspore/core/ir/param_info.h b/mindspore/core/ir/param_info.h
index 216b1f025c..4d1b837b72 100644
--- a/mindspore/core/ir/param_info.h
+++ b/mindspore/core/ir/param_info.h
@@ -75,6 +75,12 @@ class ParamInfo {
     return clone;
   }
 
+  int32_t comm_fusion() const { return fusion_type_; }
+  void set_comm_fusion(int32_t fusion_type) { fusion_type_ = fusion_type; }
+
+  bool parallel_optimizer() const { return parallel_optimizer_; }
+  void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }
+
  private:
   std::string name_{"Parameter"};
   bool requires_grad_{true};
@@ -84,6 +90,8 @@ class ParamInfo {
   bool cloned_{false};
   std::vector<int32_t> be_cloned_index_;
   int32_t cloned_index_{0};
+  int32_t fusion_type_{1};
+  bool parallel_optimizer_{true};
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_IR_PARAM_INFO_H_
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index a126766a90..b1072ac97d 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -1075,6 +1075,12 @@ class Cell(Cell_):
         for param in params:
             param.set_param_ps(init_in_server)
 
+    def set_comm_fusion(self, fusion_type, recurse=True):
+        Validator.check_is_int(fusion_type)
+        for param in self.trainable_params(recurse):
+            param.comm_fusion = fusion_type
+        return self
+
 
 class GraphKernel(Cell):
     """
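Cell.set_comm_fusion above stamps one fusion index onto every trainable parameter (recursively by default) and returns self so the call can be chained, as the updated unit test below does with TrainOneStepCell(...).set_comm_fusion(4). A short sketch under that assumption; the cell itself is hypothetical:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Parameter, Tensor

    class TinyNet(nn.Cell):  # hypothetical cell, for illustration only
        def __init__(self):
            super(TinyNet, self).__init__()
            self.w = Parameter(Tensor(np.ones([4, 4]).astype(np.float32)), name="w")

        def construct(self, x):
            return x * self.w

    # Every trainable parameter of the cell now carries comm_fusion == 3.
    net = TinyNet().set_comm_fusion(3)
    assert net.w.comm_fusion == 3
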
diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index 6675bbba03..d730458bbe 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -125,7 +125,7 @@ def get_bprop_all_gather(self):
             instance_name = "grad_" + self.instance_name
             reduce_scatter.set_prim_instance_name(instance_name)
     else:
-        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", 1)
+        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
         if self.instance_name:
             instance_name = "grad_" + self.instance_name
             all_reduce.set_prim_instance_name(instance_name)
@@ -240,9 +240,7 @@ def get_bprop_mirror_operator(self):
     mul = P.Mul()
     cast = P.Cast()
 
-    fusion = 1
-    if hasattr(self, 'fusion'):
-        fusion = self.fusion
+    fusion = self.get_attr_dict()["fusion"]
     all_reduce.add_prim_attr("fusion", fusion)
     if hasattr(self, 'parameter'):
         parameter = self.parameter
diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py
index c7e46fdf53..92f8fdec75 100644
--- a/mindspore/ops/operations/comm_ops.py
+++ b/mindspore/ops/operations/comm_ops.py
@@ -534,6 +534,7 @@ class _MirrorOperator(PrimitiveWithInfer):
         self.group = group
         self.dev_num = dev_num
         self.mean_flag = mean_flag
+        self.add_prim_attr("fusion", 1)
 
     def infer_shape(self, x_shape):
         return x_shape
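Because _MirrorOperator now registers a default "fusion" attribute at construction, the mirror bprop above can read it unconditionally through get_attr_dict() instead of probing with hasattr. A hedged sketch of that assumption (the group name is only a placeholder and no distributed initialization is performed):

    from mindspore.ops.operations.comm_ops import _MirrorOperator

    # The constructor calls add_prim_attr("fusion", 1), so the key is expected
    # to be present in the primitive's attribute dict right away.
    mirror = _MirrorOperator(group="group_name", dev_num=8, mean_flag=True)
    print(mirror.get_attr_dict()["fusion"])  # expected: 1
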
diff --git a/tests/ut/python/parallel/test_allreduce_fusion.py b/tests/ut/python/parallel/test_allreduce_fusion.py
index 0bb30b2ae9..659560ac97 100644
--- a/tests/ut/python/parallel/test_allreduce_fusion.py
+++ b/tests/ut/python/parallel/test_allreduce_fusion.py
@@ -25,6 +25,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.train import Model
 from mindspore.context import ParallelMode
 from tests.dataset_mock import MindData
+import pytest
 
 
 class Dataset(MindData):
@@ -125,6 +126,7 @@ def train_common(net):
     return allreduce_fusion_dict
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion_parameters():
     cost_model_context.reset_cost_model_context()
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
@@ -181,6 +183,7 @@ def test_allreduce_fusion_parameters():
     assert computation_time_parameter == 0.1
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion1():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
@@ -205,6 +208,7 @@ def test_allreduce_fusion1():
     cost_model_context.reset_cost_model_context()
 
 
+@pytest.mark.skip(reason="deprecated feature")
 # reset_cost_model_context is called, the default value of costmodel_allreduce_fusion_times is 0, step_allreduce_fusion
 # is bypassed.
 def test_allreduce_fusion2():
@@ -220,6 +224,7 @@ def test_allreduce_fusion2():
     cost_model_context.reset_cost_model_context()
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion3():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3)
@@ -248,6 +253,7 @@ def test_allreduce_fusion3():
     cost_model_context.reset_cost_model_context()
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion4():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
@@ -277,6 +283,7 @@ def test_allreduce_fusion4():
     cost_model_context.reset_cost_model_context()
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion5():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py
index c850643745..6659d7c3be 100644
--- a/tests/ut/python/parallel/test_parallel_optimizer.py
+++ b/tests/ut/python/parallel/test_parallel_optimizer.py
@@ -66,15 +66,30 @@ class Net2(nn.Cell):
         return x - y
 
 
-def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
+class Net3(nn.Cell):
+    """Net definition"""
+    def __init__(self, strategy1, strategy2):
+        super(Net3, self).__init__()
+        self.fc1 = P.MatMul().shard(strategy=strategy1)
+        self.fc2 = P.MatMul().shard(strategy=strategy2)
+        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
+        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)
+
+    def construct(self, x, y):
+        x = self.fc1(x, self.p1)
+        x = self.fc2(x, self.p2)
+        return x - y
+
+
+def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None):
     context.set_context(mode=context.GRAPH_MODE)
     context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
     inputs = Tensor(np.ones([32, 48]).astype(np.float32))
     label = Tensor(np.zeros([32, 16]).astype(np.float32))
-    net = Net2(strategy1, strategy2)
+    net = net(strategy1, strategy2)
     net = _VirtualDatasetCell(net)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
-    train_network = TrainOneStepCell(net, optimizer)
+    train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
     train_network.set_auto_parallel()
     train_network.set_train()
     _executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
@@ -83,18 +98,18 @@ def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
 
 
 def test_auto_parallel_momentum_1():
-    auto_parallel_compile_net("auto_parallel", 8)
+    auto_parallel_compile_net("auto_parallel", 8, Net2)
 
 
 def test_auto_parallel_momentum_2():
     # data parallel case
-    auto_parallel_compile_net("auto_parallel", 8, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
+    auto_parallel_compile_net("auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
 
 
 def test_auto_parallel_momentum_3():
     # hybrid parallel case
     # weight1 could not be shard and weight2 is repeated
-    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
     param_dict = train_network.parameter_layout_dict
     # validate opt_shard_group
     assert not param_dict["weight1"][5]
@@ -104,7 +119,16 @@ def test_auto_parallel_momentum_3():
 def test_auto_parallel_momentum_4():
     # hybrid parallel cases
     # devices are repeatedly used
-    auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
+    auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
+
+
+def test_auto_parallel_momentum_5():
+    # test parallel optimizer filter
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net3, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    param_dict = train_network.parameter_layout_dict
+    # validate opt_shard_group
+    assert not param_dict["weight1"][5]
+    assert not param_dict["weight2"][5]
 
 
 def test_AdamWeightDecay():
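The new test_auto_parallel_momentum_5 ties the pieces together: the optimizer weight shard is switched on globally while weight2 opts out through the Parameter flag, so its opt_shard_group entry stays empty just like weight1's (which cannot be sharded under that strategy). A condensed, hedged sketch of the global switch the test relies on (device number is illustrative):

    from mindspore import context

    # Enable optimizer weight shard globally; individual parameters can still
    # opt out by being constructed with parallel_optimizer=False (see Net3 above).
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32,
                                      enable_parallel_optimizer=True)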