update field split

pull/4356/head
yangzhenzhang 5 years ago
parent a7556d874d
commit 4a0e6ff7fc

@ -44,14 +44,15 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
auto device_arrangement = tensor_layout->device_arrangement().array();
auto tensor_map = tensor_layout->tensor_map().array();
auto slice_shape = tensor_layout->slice_shape().array();
int32_t _field_size = tensor_layout->get_field_size();
Shape field_size;
if (_field_size != 0) {
field_size.push_back(_field_size);
Shape field_size = {tensor_layout->get_field_size()};
Shape uniform_split;
if (tensor_layout->uniform_split()) {
uniform_split.push_back(1);
} else {
field_size = {0};
uniform_split.push_back(0);
}
std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size};
std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split};
dict[py::str(name)] = layout;
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
}

@ -59,7 +59,9 @@ class GatherV2PInfo : public OperatorInfo {
Status GetAttrs() override;
Status ComputeReplaceGraph(const CNodePtr &cnode);
Status CheckManualSplit();
Status CheckManualSplit(const Strategys &strategy);
Status GetManualSplitAttr();
Status GetManualSplitWithoutOffsetAttr();
Status ComputeReplaceOp();
Status InferBias();
Status InferOffset();

@ -48,6 +48,10 @@ class TensorLayout {
void set_field_size(int32_t field_size) { field_size_ = field_size; }
bool uniform_split() const { return uniform_split_; }
void set_uniform_split(bool flag) { uniform_split_ = flag; }
Arrangement device_arrangement() const { return device_arrangement_; }
Map tensor_map() const { return tensor_map_; }
@ -104,6 +108,7 @@ class TensorLayout {
Arrangement tensor_shape_;
bool skip_redistribution_ = false;
int32_t field_size_ = 0;
bool uniform_split_ = true;
};
} // namespace parallel
} // namespace mindspore

@ -229,10 +229,13 @@ def _load_tensor_by_layout(tensor, layout):
"""
if not isinstance(layout, list):
raise TypeError("The layout should be list! layout is {}".format(layout))
if len(layout) < 3:
raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
if len(layout) < 5:
raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
dev_mat = layout[0]
tensor_map = layout[1]
uniform_split = layout[4]
if uniform_split[0] == 0:
raise RuntimeError("The load tensor only support uniform split now")
if tensor.size() == 1:
return tensor
return _load_tensor(tensor, dev_mat, tensor_map)

@ -49,8 +49,8 @@ def test_get_parameter_layout():
net.set_auto_parallel()
exe = me._executor
exe.compile(net, x, phase='train', auto_parallel_mode=True)
x_layout = [[2, 4], [1, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = [[2, 4], [0, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [0, -1]
x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1]
expect_dict = {'x': x_layout, 'w1': weight_layout}
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
assert net.parameter_layout_dict == expect_dict

@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
@ -22,40 +23,170 @@ from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
class Net(Cell):
def __init__(self, strategy1=None, strategy2=None, strategy3=None):
def __init__(self,
strategy1=None,
strategy2=None,
strategy3=None,
axis=0,
init_flag=True,
split_tuple=(4, 4),
split_string="manual_split",
param_shape=(8, 8)):
super().__init__()
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
self.gatherv2.add_prim_attr(split_string, split_tuple)
self.mul = P.Mul().set_strategy(strategy2)
self.reshape = P.Reshape()
self.matmul = P.MatMul().set_strategy(strategy3)
self.matmul.add_prim_attr("forward_reduce_scatter", True)
self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
if init_flag:
self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
else:
self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
self.axis = axis
def construct(self, x, b):
out = self.gatherv2(self.param, x, 0)
out = self.gatherv2(self.param, x, self.axis)
out = self.mul(out, self.mul_weight)
out = self.reshape(out, (2, 256))
out = self.reshape(out, (8, 64))
out = self.matmul(out, self.matmul_weight)
return out
_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
_b = Tensor(np.ones([64, 8]), dtype=ms.float32)
def compile_net(net):
context.set_context(save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
_executor.compile(train_net, _x, _b)
_executor.compile(train_net, _x, _b, auto_parallel_mode=True)
context.reset_auto_parallel_context()
def test_neg_data_parallel():
context.set_context(save_graphs=True)
def test_normal_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
compile_net(net)
def test_normal_split2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
strategy1 = ((4, 1), (1, 4))
strategy2 = ((1, 4, 1), (1, 4, 1))
strategy3 = ((1, 4), (4, 1))
net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
compile_net(net)
def test_normal_split3():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
strategy1 = ((4, 8), (1, 4))
strategy2 = ((1, 4, 8), (1, 4, 8))
strategy3 = ((1, 32), (32, 1))
net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
compile_net(net)
def test_normal_split_with_offset():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
compile_net(net)
def test_auto_parallel_error():
context.set_context(save_graphs=True)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
net = Net()
with pytest.raises(RuntimeError):
compile_net(net)
def test_axis_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3, axis=1)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 1), (8, 1))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 1), (1, 8))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error3():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error4():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 8), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error5():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
strategy1 = ((4, 1), (1, 4))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_split_tuple_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
with pytest.raises(RuntimeError):
compile_net(net)
def test_parameter_use_tensor_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3, init_flag=False)
with pytest.raises(RuntimeError):
compile_net(net)

Loading…
Cancel
Save