!4356 Add validation for field split

Merge pull request !4356 from yangzhenzhang/update-field-split
pull/4356/MERGE
mindspore-ci-bot authored 5 years ago · committed by Gitee
commit 2db0290c49

@@ -44,14 +44,15 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
       auto device_arrangement = tensor_layout->device_arrangement().array();
       auto tensor_map = tensor_layout->tensor_map().array();
       auto slice_shape = tensor_layout->slice_shape().array();
-      int32_t _field_size = tensor_layout->get_field_size();
-      Shape field_size;
-      if (_field_size != 0) {
-        field_size.push_back(_field_size);
+      Shape field_size = {tensor_layout->get_field_size()};
+      Shape uniform_split;
+      if (tensor_layout->uniform_split()) {
+        uniform_split.push_back(1);
       } else {
-        field_size = {0};
+        uniform_split.push_back(0);
       }
-      std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size};
+      std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split};
       dict[py::str(name)] = layout;
       MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
     }
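Note for reviewers: with this change GetParameterLayout emits a five-element list per parameter instead of four. A minimal Python-side sketch of unpacking that list (the values are hypothetical, borrowed from the updated UT further down):

```python
# Hypothetical 5-element layout as emitted by GetParameterLayout:
# [device_arrangement, tensor_map, slice_shape, field_size, uniform_split]
layout = [[2, 4], [1, -1], [16, 32], [0], [1]]

dev_mat, tensor_map, slice_shape, field_size, uniform_split = layout
# uniform_split is a one-element list: [1] means every slice has the same
# shape, [0] means the parameter was split manually into unequal slices.
if uniform_split[0] == 0:
    raise RuntimeError("non-uniform splits need special handling downstream")
```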

@@ -59,7 +59,9 @@ class GatherV2PInfo : public OperatorInfo {
   Status GetAttrs() override;
   Status ComputeReplaceGraph(const CNodePtr &cnode);
-  Status CheckManualSplit();
+  Status CheckManualSplit(const Strategys &strategy);
+  Status GetManualSplitAttr();
+  Status GetManualSplitWithoutOffsetAttr();
   Status ComputeReplaceOp();
   Status InferBias();
   Status InferOffset();
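The split attribute now comes in two spellings: GetManualSplitWithoutOffsetAttr parses plain slice sizes ("manual_split"), GetManualSplitAttr parses (size, offset) pairs ("manual_split_with_offset"), and CheckManualSplit validates the result against the chosen strategy. A hedged Python sketch of the kind of check the UTs below imply; this is an assumption, not the C++ implementation:

```python
# Assumed rule, inferred from the tests: one split element per shard of the
# parameter's first dimension, and the sizes must cover that dimension fully.
def check_manual_split(split_tuple, dim_size, shard_num):
    if len(split_tuple) != shard_num:
        raise RuntimeError("need one split element per shard")
    if sum(split_tuple) != dim_size:
        raise RuntimeError("split sizes must sum to the dimension size")

check_manual_split((4, 4), dim_size=8, shard_num=2)            # ok
check_manual_split((10, 20, 30, 4), dim_size=64, shard_num=4)  # ok
# check_manual_split((5, 5), dim_size=8, shard_num=2)          # would raise
```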

@@ -48,6 +48,10 @@ class TensorLayout {
   void set_field_size(int32_t field_size) { field_size_ = field_size; }
 
+  bool uniform_split() const { return uniform_split_; }
+
+  void set_uniform_split(bool flag) { uniform_split_ = flag; }
+
   Arrangement device_arrangement() const { return device_arrangement_; }
   Map tensor_map() const { return tensor_map_; }

@@ -104,6 +108,7 @@ class TensorLayout {
   Arrangement tensor_shape_;
   bool skip_redistribution_ = false;
   int32_t field_size_ = 0;
+  bool uniform_split_ = true;
 };
 }  // namespace parallel
 }  // namespace mindspore
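uniform_split_ defaults to true and is presumably cleared when a manual split yields unequal slice sizes. An illustrative check, assuming that definition of uniformity (not taken from the C++ code):

```python
# Assumed rule: a manual split is uniform iff all slice sizes are equal.
def is_uniform(split_sizes):
    return len(set(split_sizes)) <= 1

assert is_uniform((4, 4))               # equal slices -> uniform_split_ stays true
assert not is_uniform((10, 20, 30, 4))  # unequal slices -> uniform_split_ = false
```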

@@ -229,10 +229,13 @@ def _load_tensor_by_layout(tensor, layout):
     """
     if not isinstance(layout, list):
         raise TypeError("The layout should be list! layout is {}".format(layout))
-    if len(layout) < 3:
-        raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
+    if len(layout) < 5:
+        raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
+    uniform_split = layout[4]
+    if uniform_split[0] == 0:
+        raise RuntimeError("The load tensor only support uniform split now")
     if tensor.size() == 1:
         return tensor
     return _load_tensor(tensor, dev_mat, tensor_map)
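The effect on callers, shown with hypothetical layouts: a uniform layout still goes through _load_tensor, while a manual (non-uniform) one is now rejected. The guard added above is reproduced here as a standalone predicate:

```python
# Hypothetical layouts in the new 5-element format.
uniform_layout = [[2, 4], [1, -1], [16, 32], [0], [1]]
manual_layout = [[2, 4], [1, -1], [16, 32], [0], [0]]

def loadable(layout):
    # Mirrors the checks in _load_tensor_by_layout.
    return isinstance(layout, list) and len(layout) >= 5 and layout[4][0] != 0

assert loadable(uniform_layout)      # proceeds to _load_tensor
assert not loadable(manual_layout)   # raises "only support uniform split now"
```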

@@ -49,8 +49,8 @@ def test_get_parameter_layout():
     net.set_auto_parallel()
     exe = me._executor
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
-    x_layout = [[2, 4], [1, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
-    weight_layout = [[2, 4], [0, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
+    x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
+    weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
     # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
     assert net.parameter_layout_dict == expect_dict

@@ -14,6 +14,7 @@
 # ============================================================================
 import numpy as np
+import pytest
 import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.common.api import _executor
@@ -22,40 +23,170 @@ from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 
 
 class Net(Cell):
-    def __init__(self, strategy1=None, strategy2=None, strategy3=None):
+    def __init__(self,
+                 strategy1=None,
+                 strategy2=None,
+                 strategy3=None,
+                 axis=0,
+                 init_flag=True,
+                 split_tuple=(4, 4),
+                 split_string="manual_split",
+                 param_shape=(8, 8)):
         super().__init__()
         self.gatherv2 = P.GatherV2().set_strategy(strategy1)
-        self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
+        self.gatherv2.add_prim_attr(split_string, split_tuple)
         self.mul = P.Mul().set_strategy(strategy2)
         self.reshape = P.Reshape()
         self.matmul = P.MatMul().set_strategy(strategy3)
         self.matmul.add_prim_attr("forward_reduce_scatter", True)
-        self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
-        self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
-        self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
+        if init_flag:
+            self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
+        else:
+            self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
+        self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
+        self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
+        self.axis = axis
 
     def construct(self, x, b):
-        out = self.gatherv2(self.param, x, 0)
+        out = self.gatherv2(self.param, x, self.axis)
         out = self.mul(out, self.mul_weight)
-        out = self.reshape(out, (2, 256))
+        out = self.reshape(out, (8, 64))
         out = self.matmul(out, self.matmul_weight)
         return out
 
 
-_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
+_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
 _b = Tensor(np.ones([64, 8]), dtype=ms.float32)
 
 
 def compile_net(net):
-    context.set_context(save_graphs=True)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_net = TrainOneStepCell(net, optimizer)
     train_net.set_auto_parallel()
-    _executor.compile(train_net, _x, _b)
+    _executor.compile(train_net, _x, _b, auto_parallel_mode=True)
     context.reset_auto_parallel_context()
 
 
-def test_neg_data_parallel():
+def test_normal_split():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    compile_net(net)
+
+
+def test_normal_split2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy1 = ((4, 1), (1, 4))
+    strategy2 = ((1, 4, 1), (1, 4, 1))
+    strategy3 = ((1, 4), (4, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
+    compile_net(net)
+
+
+def test_normal_split3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
+    strategy1 = ((4, 8), (1, 4))
+    strategy2 = ((1, 4, 8), (1, 4, 8))
+    strategy3 = ((1, 32), (32, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
+    compile_net(net)
+
+
+def test_normal_split_with_offset():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
+    compile_net(net)
+
+
+def test_auto_parallel_error():
     context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
+    net = Net()
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_axis_error():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
     strategy1 = ((2, 1), (1, 2))
     strategy2 = ((1, 2, 1), (1, 2, 1))
     strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, axis=1)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((4, 1), (8, 1))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
     net = Net(strategy1, strategy2, strategy3)
-    compile_net(net)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((4, 1), (1, 8))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error4():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 8), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error5():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy1 = ((4, 1), (1, 4))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_split_tuple_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_parameter_use_tensor_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, init_flag=False)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
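For readers skimming the UTs: the two attribute spellings exercised above can be summarized as follows, with the strategy and shapes taken from test_normal_split and test_normal_split_with_offset:

```python
from mindspore.ops import operations as P

# Plain sizes: an (8, 8) parameter sharded 2-ways on dim 0, slices of 4 rows.
gather = P.GatherV2().set_strategy(((2, 1), (1, 2)))
gather.add_prim_attr("manual_split", (4, 4))

# (size, offset) pairs: slice 0 covers rows [0, 4), slice 1 covers rows [4, 8).
gather_offset = P.GatherV2().set_strategy(((2, 1), (1, 2)))
gather_offset.add_prim_attr("manual_split_with_offset", ((4, 0), (4, 4)))
```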
