!9832 expose_allgather_fusion_to_users

From: @gong_zi_yan
Reviewed-by: 
Signed-off-by:
pull/9832/MERGE
Committed by mindspore-ci-bot via Gitee
commit b67aaf6773

@@ -942,6 +942,29 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
return (type_id != kNumberTypeFloat32);
}
static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
MS_EXCEPTION_IF_NULL(comm_node);
MS_EXCEPTION_IF_NULL(param_node);
if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
return;
}
auto param = param_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
auto param_info = param->param_info();
if (!param_info) {
MS_LOG(WARNING) << param->ToString() << " does not have parameter info.";
return;
}
int32_t fusion_type = param_info->comm_fusion();
attrs[FUSION] = MakeValue<int64_t>(fusion_type);
prim->SetAttrs(attrs);
MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
}
void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
size_t node_size = node->inputs().size();
@@ -1006,11 +1029,19 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtr pre_node = cnode->input(1);
InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name);
auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
// add fusion flag
// the mirror for pipeline parallelism is not set here, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first);
}
} else {
for (auto &op : backward_op) {
AnfNodePtr pre_node = node->input(index);
InsertNode(op, node, index, pre_node, func_graph, instance_name);
auto comm_op = node->input(index)->cast<CNodePtr>();
// add fusion flag
// the mirror for pipeline parallelism is not set here, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first);
}
}
}
@@ -1342,7 +1373,8 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
return std::make_pair(nullptr, 0);
}
void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, const AnfNodePtr &parameter) {
static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res,
const AnfNodePtr &parameter) {
Operator op = CreateAllGatherOp(group);
MS_EXCEPTION_IF_NULL(res.first);
MS_EXCEPTION_IF_NULL(parameter);
@@ -1360,11 +1392,7 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int
}
// add fusion flag
MS_EXCEPTION_IF_NULL(allgather);
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
auto attrs = prim->attrs();
// enable fusion flag later when it's supported in backend
attrs["fusion"] = MakeValue<int64_t>(1);
prim->SetAttrs(attrs);
AddCommOpFusionType(allgather, parameter);
}
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
@@ -1419,6 +1447,9 @@ std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNod
if (!ParameterRequireGrad(parameter)) {
// only trainable parameters need parallel optimizer
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
} else if (parameter->cast<ParameterPtr>()->param_info() &&
!parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
} else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
// get a totally shard tensor slice shape if the weight is repeated on devices
// and the shape of the first dimension could be divided
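
For orientation, the weight-shard filter added to SetParallelShape above combines two user-visible switches: the global `enable_parallel_optimizer` flag and the new per-parameter `parallel_optimizer` flag. A minimal Python sketch of how a user drives them (illustrative only, not part of this diff; an 8-device semi-auto-parallel setup is assumed):

    import numpy as np
    from mindspore import Parameter, Tensor, context

    # Global switch: the per-parameter flag only matters when this is enabled.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8,
                                      enable_parallel_optimizer=True)

    # Per-parameter switch: this weight is skipped by the optimizer weight shard
    # (the "does not need weight shard" branch above).
    w = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="w",
                  parallel_optimizer=False)
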

@@ -29,6 +29,9 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
.def_property("init_in_server", &ParamInfo::init_in_server, &ParamInfo::set_init_in_server)
.def_property("layerwise_parallel", &ParamInfo::layerwise_parallel,
&ParamInfo::set_layerwise_parallel)
.def_property("parallel_optimizer", &ParamInfo::parallel_optimizer,
&ParamInfo::set_parallel_optimizer)
.def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
.def(py::pickle(
[](const ParamInfo &p) { // __getstate__
return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());

@@ -75,8 +75,11 @@ class Parameter(MetaTensor_):
default_input (Union[Tensor, MetaTensor, Number]): Parameter data to be initialized.
name (str): Name of the child parameter. Default: None.
requires_grad (bool): True if the parameter requires gradient. Default: True.
layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
layerwise_parallel (bool): When layerwise_parallel is true in data parallel mode,
broadcast and gradient communication will not be applied to the parameter. Default: False.
parallel_optimizer (bool): Whether the parameter takes part in the weight-shard operation of the parallel
optimizer in semi-auto or auto parallel mode. It takes effect only when the parallel optimizer is enabled
in `mindspore.context.set_auto_parallel_context()`. Default: True.
Example:
>>> from mindspore import Parameter, Tensor
@@ -132,19 +135,21 @@ class Parameter(MetaTensor_):
return (
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False):
def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
self._param_info = ParamInfo()
self.init_in_server = False
self.cache_enable = False
self.name = name
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
self.parallel_optimizer = parallel_optimizer
# this flag for tensor copy data.
self.init_flag = False
# this flag is for ge variable copy data.
self._is_init = False
self._inited_param = None
self._sliced = False
self.comm_fusion = 1
self.is_param_ps = False
self._cast_type = None
self._unique = False
@@ -210,7 +215,6 @@ class Parameter(MetaTensor_):
raise RuntimeError("Must complete following two steps before calling set_param_ps: \
1. set_ps_context(enable_ps=True) \
2. export MS_ROLE environment variable.")
if init_in_server and (not self.name.endswith("embedding_table")):
raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
"sparse operator support initialization in server.".format(self.name))
@@ -218,7 +222,6 @@ class Parameter(MetaTensor_):
self.init_in_server = init_in_server
self._param_info.init_in_server = init_in_server
@property
def inited_param(self):
"""
@@ -273,6 +276,16 @@ class Parameter(MetaTensor_):
def sliced(self, sliced_):
self._sliced = sliced_
@property
def comm_fusion(self):
"""Get the fusion type for communication operators corresponding to this parameter."""
return self._param_info.comm_fusion
@comm_fusion.setter
def comm_fusion(self, comm_fusion_):
"""Set the fusion type for communication operators corresponding to this parameter."""
self._param_info.comm_fusion = comm_fusion_
@property
def unique(self):
"""whether the parameter is already unique or not."""
@@ -338,6 +351,17 @@ class Parameter(MetaTensor_):
raise TypeError("`layerwise_parallel` parameter must be bool type")
self._param_info.layerwise_parallel = value
@property
def parallel_optimizer(self):
"""Return whether the parameter requires weight shard for parallel optimizer."""
return self._param_info.parallel_optimizer
@parallel_optimizer.setter
def parallel_optimizer(self, value=True):
if not isinstance(value, bool):
raise TypeError("`parallel_optimizer` parameter must be bool type")
self._param_info.parallel_optimizer = value
@property
def requires_grad(self):
"""Return whether the parameter requires gradient."""

@@ -75,6 +75,12 @@ class ParamInfo {
return clone;
}
int32_t comm_fusion() const { return fusion_type_; }
void set_comm_fusion(int32_t fusion_type) { fusion_type_ = fusion_type; }
bool parallel_optimizer() const { return parallel_optimizer_; }
void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }
private:
std::string name_{"Parameter"};
bool requires_grad_{true};
@@ -84,6 +90,8 @@ class ParamInfo {
bool cloned_{false};
std::vector<int32_t> be_cloned_index_;
int32_t cloned_index_{0};
int32_t fusion_type_{1};
bool parallel_optimizer_{true};
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_PARAM_INFO_H_

@@ -1087,6 +1087,12 @@ class Cell(Cell_):
for param in params:
param.set_param_ps(init_in_server)
def set_comm_fusion(self, fusion_type, recurse=True):
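"""Set the communication fusion type for every trainable parameter in this cell and, when recurse is True, in its sub-cells."""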
Validator.check_is_int(fusion_type)
for param in self.trainable_params(recurse):
param.comm_fusion = fusion_type
return self
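
A minimal usage sketch for the new Cell.set_comm_fusion above (the tiny net is illustrative and not part of this change):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.operations as P
    from mindspore import Parameter, Tensor

    class TinyNet(nn.Cell):
        def __init__(self):
            super(TinyNet, self).__init__()
            self.weight = Parameter(Tensor(np.ones([32, 16]).astype(np.float32)), name="weight")
            self.matmul = P.MatMul()
        def construct(self, x):
            return self.matmul(x, self.weight)

    net = TinyNet().set_comm_fusion(4)   # returns self, so it can be chained as in the tests below
    assert all(p.comm_fusion == 4 for p in net.trainable_params())
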
class GraphKernel(Cell):
"""

@@ -127,7 +127,7 @@ def get_bprop_all_gather(self):
instance_name = "grad_" + self.instance_name
reduce_scatter.set_prim_instance_name(instance_name)
else:
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", 1)
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
if self.instance_name:
instance_name = "grad_" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
@@ -242,9 +242,7 @@ def get_bprop_mirror_operator(self):
mul = P.Mul()
cast = P.Cast()
fusion = 1
if hasattr(self, 'fusion'):
fusion = self.fusion
fusion = self.get_attr_dict()["fusion"]
all_reduce.add_prim_attr("fusion", fusion)
if hasattr(self, 'parameter'):
parameter = self.parameter

@@ -555,6 +555,7 @@ class _MirrorOperator(PrimitiveWithInfer):
self.group = group
self.dev_num = dev_num
self.mean_flag = mean_flag
self.add_prim_attr("fusion", 1)
def infer_shape(self, x_shape):
return x_shape
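
The bprop changes above stop hard-coding fusion=1 and instead read the primitive's own "fusion" attribute, which _MirrorOperator now initializes to 1 and AddCommOpFusionType overrides with the parameter's comm_fusion. A minimal sketch of that attribute flow (illustrative only; the group string is a placeholder and no communication backend is initialized):

    from mindspore.ops.operations.comm_ops import _MirrorOperator

    mirror = _MirrorOperator(group="hccl_world_group", dev_num=8, mean_flag=True)
    print(mirror.get_attr_dict()["fusion"])   # 1, the default set in __init__
    mirror.add_prim_attr("fusion", 3)         # roughly what AddCommOpFusionType does on the C++ side
    print(mirror.get_attr_dict()["fusion"])   # 3; the gradient AllReduce now picks this value up
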

@@ -25,6 +25,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
import pytest
class Dataset(MindData):
@@ -125,6 +126,7 @@ def train_common(net):
return allreduce_fusion_dict
@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion_parameters():
cost_model_context.reset_cost_model_context()
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
@@ -181,6 +183,7 @@ def test_allreduce_fusion_parameters():
assert computation_time_parameter == 0.1
@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion1():
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
@@ -205,6 +208,7 @@ def test_allreduce_fusion1():
cost_model_context.reset_cost_model_context()
@pytest.mark.skip(reason="deprecated feature")
# After reset_cost_model_context is called, the default value of costmodel_allreduce_fusion_times is 0,
# so step_allreduce_fusion is bypassed.
def test_allreduce_fusion2():
@@ -220,6 +224,7 @@ def test_allreduce_fusion2():
cost_model_context.reset_cost_model_context()
@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion3():
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3)
@@ -248,6 +253,7 @@ def test_allreduce_fusion3():
cost_model_context.reset_cost_model_context()
@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion4():
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
@@ -277,6 +283,7 @@ def test_allreduce_fusion4():
cost_model_context.reset_cost_model_context()
@pytest.mark.skip(reason="deprecated feature")
def test_allreduce_fusion5():
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)

@@ -66,15 +66,30 @@ class Net2(nn.Cell):
return x - y
def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
class Net3(nn.Cell):
"""Net definition"""
def __init__(self, strategy1, strategy2):
super(Net3, self).__init__()
self.fc1 = P.MatMul().shard(strategy=strategy1)
self.fc2 = P.MatMul().shard(strategy=strategy2)
self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)
def construct(self, x, y):
x = self.fc1(x, self.p1)
x = self.fc2(x, self.p2)
return x - y
def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None):
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
inputs = Tensor(np.ones([32, 48]).astype(np.float32))
label = Tensor(np.zeros([32, 16]).astype(np.float32))
net = Net2(strategy1, strategy2)
net = net(strategy1, strategy2)
net = _VirtualDatasetCell(net)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_network = TrainOneStepCell(net, optimizer)
train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
train_network.set_auto_parallel()
train_network.set_train()
_executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
@@ -83,18 +98,18 @@ def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
def test_auto_parallel_momentum_1():
auto_parallel_compile_net("auto_parallel", 8)
auto_parallel_compile_net("auto_parallel", 8, Net2)
def test_auto_parallel_momentum_2():
# data parallel case
auto_parallel_compile_net("auto_parallel", 8, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
auto_parallel_compile_net("auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
def test_auto_parallel_momentum_3():
# hybrid parallel case
# weight1 cannot be sharded and weight2 is repeated
train_network = auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
param_dict = train_network.parameter_layout_dict
# validate opt_shard_group
assert not param_dict["weight1"][5]
@@ -104,7 +119,16 @@ def test_auto_parallel_momentum_3():
def test_auto_parallel_momentum_4():
# hybrid parallel cases
# devices are repeatedly used
auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
def test_auto_parallel_momentum_5():
# test parallel optimizer filter
train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net3, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
param_dict = train_network.parameter_layout_dict
# validate opt_shard_group
assert not param_dict["weight1"][5]
assert not param_dict["weight2"][5]
def test_AdamWeightDecay():
