change API set_strategy() to shard()

pull/5991/head
Yi Huaijie 4 years ago
parent 20991a0718
commit a836d25c64

@ -567,8 +567,8 @@ class Conv2dTranspose(_Conv):
else:
self.padding_top, self.padding_bottom, self.padding_left, self.padding_right = self.padding
def set_strategy(self, strategy):
self.conv2d_transpose.set_strategy(strategy)
def shard(self, strategy):
self.conv2d_transpose.shard(strategy)
return self
def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size, padding):
@ -744,8 +744,8 @@ class Conv1dTranspose(_Conv):
self.expand_dims = P.ExpandDims()
self.squeeze = P.Squeeze(2)
def set_strategy(self, strategy):
self.conv2d_transpose.set_strategy(strategy)
def shard(self, strategy):
self.conv2d_transpose.shard(strategy)
return self
def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size, padding):

@ -174,17 +174,17 @@ class EmbeddingLookup(Cell):
Validator.check_integer('manul shape dim', dim, 0, Rel.GT, self.cls_name)
self.gatherv2.add_prim_attr("manual_split", manual_shapes)
self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size())))
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size())))
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
elif slice_mode == "table_row_slice" and is_auto_parallel:
self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1)))
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
self.gatherv2.shard(((get_group_size(), 1), (1, 1)))
self.embeddinglookup.shard(((get_group_size(), 1), (1, 1)))
elif slice_mode == "table_column_slice" and is_auto_parallel:
self.gatherv2.set_strategy(((1, get_group_size()), (1, 1)))
self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
self.gatherv2.shard(((1, get_group_size()), (1, 1)))
self.embeddinglookup.shard(((1, get_group_size()), (1, 1)))
elif slice_mode == "batch_slice" and is_auto_parallel:
self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1)))
self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1)))
self.gatherv2.shard(((1, 1), (get_group_size(), 1)))
self.embeddinglookup.shard(((1, 1), (get_group_size(), 1)))
else:
if is_auto_parallel:
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "

@ -112,12 +112,12 @@ class _BatchNorm(Cell):
data_parallel_strategy = ((1,), (1,))
data_parallel_strategy_one = ((1,), ())
self.sub_mean = P.Sub().set_strategy(data_parallel_strategy)
self.sub_var = P.Sub().set_strategy(data_parallel_strategy)
self.mul_mean = P.Mul().set_strategy(data_parallel_strategy_one)
self.mul_var = P.Mul().set_strategy(data_parallel_strategy_one)
self.assign_sub_mean = P.AssignSub().set_strategy(data_parallel_strategy)
self.assign_sub_var = P.AssignSub().set_strategy(data_parallel_strategy)
self.sub_mean = P.Sub().shard(data_parallel_strategy)
self.sub_var = P.Sub().shard(data_parallel_strategy)
self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
self.mul_var = P.Mul().shard(data_parallel_strategy_one)
self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
def _check_data_dim(self, x):
raise NotImplementedError

@ -102,7 +102,7 @@ class Primitive(Primitive_):
self.add_attr(name, value)
return self
def set_strategy(self, strategy):
def shard(self, strategy):
"""
Add strategies to primitive attribute.

@ -198,14 +198,14 @@ class WideDeepModel(nn.Cell):
self.concat = P.Concat(axis=1)
self.cast = P.Cast()
if is_auto_parallel and host_device_mix and not is_field_slice:
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),))
self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),))
self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1)))
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE)
self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1)))
self.deep_reshape.add_prim_attr("skip_redistribution", True)
self.reduce_sum.add_prim_attr("cross_batch", True)
self.embedding_table = self.deep_embeddinglookup.embedding_table
@ -217,12 +217,12 @@ class WideDeepModel(nn.Cell):
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
slice_mode=nn.EmbeddingLookup.FIELD_SLICE,
manual_shapes=manual_shapes)
self.deep_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
self.wide_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
self.reduce_sum.set_strategy(((1, get_group_size(), 1),))
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
self.deep_mul.shard(((1, get_group_size(), 1), (1, get_group_size(), 1)))
self.wide_mul.shard(((1, get_group_size(), 1), (1, get_group_size(), 1)))
self.reduce_sum.shard(((1, get_group_size(), 1),))
self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),))
self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),))
self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1)))
self.embedding_table = self.deep_embeddinglookup.embedding_table
elif parameter_server:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)

@ -51,12 +51,12 @@ class Onehot(Cell):
trans_stra = None
if strategy:
trans_stra = (strategy[0],)
self.onehot = P.OneHot().set_strategy(strategy=strategy)
self.onehot = P.OneHot().shard(strategy=strategy)
self.depth = depth
self.on_value = Tensor(on_value, ms.float32)
self.off_value = Tensor(off_value, ms.float32)
self.transpose = P.Transpose().set_strategy(strategy=trans_stra)
self.sub = P.Sub().set_strategy(strategy=((1, 1), (1, 1)))
self.transpose = P.Transpose().shard(strategy=trans_stra)
self.sub = P.Sub().shard(strategy=((1, 1), (1, 1)))
self.axis = axis
def construct(self, input_, indices):

@ -140,20 +140,20 @@ class SoftmaxCrossEntropyExpand(Cell):
if len(stra_list) < 11:
stra_list = [None] * 11
self.exp = P.Exp()
self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy=stra_list[1])
self.onehot = P.OneHot().set_strategy(strategy=stra_list[2])
self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy=stra_list[1])
self.onehot = P.OneHot().shard(strategy=stra_list[2])
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.div = P.Div().set_strategy(strategy=stra_list[3])
self.log = P.Log().set_strategy(strategy=stra_list[4])
self.sum_cross_entropy = P.ReduceSum(keep_dims=False).set_strategy(strategy=stra_list[5])
self.mul = P.Mul().set_strategy(strategy=stra_list[6])
self.mul2 = P.Mul().set_strategy(strategy=stra_list[7])
self.div = P.Div().shard(strategy=stra_list[3])
self.log = P.Log().shard(strategy=stra_list[4])
self.sum_cross_entropy = P.ReduceSum(keep_dims=False).shard(strategy=stra_list[5])
self.mul = P.Mul().shard(strategy=stra_list[6])
self.mul2 = P.Mul().shard(strategy=stra_list[7])
self.cast = P.Cast()
self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(strategy=stra_list[8])
self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy=stra_list[8])
self.sparse = sparse
self.reduce_max = P.ReduceMax(keep_dims=True).set_strategy(strategy=stra_list[9])
self.sub = P.Sub().set_strategy(strategy=stra_list[10])
self.reduce_max = P.ReduceMax(keep_dims=True).shard(strategy=stra_list[9])
self.sub = P.Sub().shard(strategy=stra_list[10])
def construct(self, logit, label):
logit_max = self.reduce_max(logit, -1)
@ -174,7 +174,7 @@ class MatmulNet(Cell):
super(MatmulNet, self).__init__()
if loss_stra_list is None:
loss_stra_list = []
self.matmul = P.MatMul(transpose_b=True).set_strategy(strategy=matmul_stra)
self.matmul = P.MatMul(transpose_b=True).shard(strategy=matmul_stra)
self.loss = SoftmaxCrossEntropyExpand(sparse=True, stra_list=loss_stra_list)
self.weight = Parameter(Tensor(np.ones(MatmulParamShape), dtype=ms.float32), name="weight")

@ -181,7 +181,7 @@ class WideDeepModel(nn.Cell):
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.gather_v2 = P.GatherV2().set_strategy(((1, 8), (1, 1)))
self.gather_v2 = P.GatherV2().shard(((1, 8), (1, 1)))
self.gather_v2_1 = P.GatherV2()
self.mul = P.Mul()
self.reduce_sum = P.ReduceSum(keep_dims=False)
@ -230,7 +230,7 @@ class NetWithLossClass(nn.Cell):
self.network = network
self.l2_coef = config.l2_coef
self.loss = P.SigmoidCrossEntropyWithLogits()
self.square = P.Square().set_strategy(((1, get_group_size()),))
self.square = P.Square().shard(((1, get_group_size()),))
self.reduceMean_false = P.ReduceMean(keep_dims=False)
self.reduceSum_false = P.ReduceSum(keep_dims=False)

@ -273,8 +273,8 @@ class DepthwiseConv2dNative(_DepthwiseConv2dNative):
dilation=self.dilation,
group=self.group)
def set_strategy(self, strategy):
self.depthwise_conv2d_native.set_strategy(strategy)
def shard(self, strategy):
self.depthwise_conv2d_native.shard(strategy)
return self
def construct(self, x):

@ -29,8 +29,8 @@ grad_all = C.GradOperation(get_all=True)
class AddRelu(nn.Cell):
def __init__(self, strategy0=None, strategy1=None):
super(AddRelu, self).__init__()
self.add = P.TensorAdd().set_strategy(strategy=strategy0)
self.relu = P.ReLU().set_strategy(strategy=strategy1)
self.add = P.TensorAdd().shard(strategy=strategy0)
self.relu = P.ReLU().shard(strategy=strategy1)
def construct(self, x, z):
out = self.add(x, z)

@ -53,9 +53,9 @@ class Dataset(MindData):
class AllToAllNet(nn.Cell):
def __init__(self, strategy1):
super(AllToAllNet, self).__init__()
self.matmul = P.MatMul().set_strategy(((1, 1), (1, 8)))
self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
self.transpose1 = P.Transpose().set_strategy(strategy1)
self.transpose1 = P.Transpose().shard(strategy1)
def construct(self, x):
x = self.matmul(x, self.matmul_weight)
@ -80,8 +80,8 @@ def all_to_all_common(strategy1):
net = all_to_all_net(strategy1)
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
loss.one_hot.set_strategy(((8, 1), (), ()))
loss.softmax_cross_entropy.shard(((8, 1), (8, 1)))
loss.one_hot.shard(((8, 1), (), ()))
opt = Momentum(net.trainable_params(), learning_rate, momentum)
model = Model(net, loss, opt)

@ -55,8 +55,8 @@ def test_matmul_sub():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.sub = P.Sub().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.sub = P.Sub().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -79,8 +79,8 @@ def test_matmul_add():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.add = P.TensorAdd().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.add = P.TensorAdd().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -103,8 +103,8 @@ def test_matmul_mul():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.mul = P.Mul().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.mul = P.Mul().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -126,8 +126,8 @@ def test_matmul_mod():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.mod = P.Mod().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.mod = P.Mod().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -149,8 +149,8 @@ def test_matmul_floormod():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.floormod = P.FloorMod().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.floormod = P.FloorMod().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -173,8 +173,8 @@ def test_matmul_atan2():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.atan2 = P.Atan2().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.atan2 = P.Atan2().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -197,8 +197,8 @@ def test_matmul_divNoNan():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.divNoNan = P.DivNoNan().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.divNoNan = P.DivNoNan().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -221,10 +221,10 @@ def test_matmul_logicaland():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.equal = P.Equal().set_strategy(strategy2)
self.notequal = P.NotEqual().set_strategy(strategy2)
self.logical = P.LogicalAnd().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.equal = P.Equal().shard(strategy2)
self.notequal = P.NotEqual().shard(strategy2)
self.logical = P.LogicalAnd().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -250,10 +250,10 @@ def test_matmul_logicalor():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.equal = P.Equal().set_strategy(strategy2)
self.notequal = P.NotEqual().set_strategy(strategy2)
self.logical = P.LogicalOr().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.equal = P.Equal().shard(strategy2)
self.notequal = P.NotEqual().shard(strategy2)
self.logical = P.LogicalOr().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -279,8 +279,8 @@ def test_matmul_div():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.div = P.Div().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.div = P.Div().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -303,8 +303,8 @@ def test_matmul_add_broadcast():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.add = P.TensorAdd().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.add = P.TensorAdd().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -327,8 +327,8 @@ def test_matmul_add_broadcast2():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.add = P.TensorAdd().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.add = P.TensorAdd().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -351,8 +351,8 @@ def test_matmul_sub_broadcast():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.sub = P.Sub().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.sub = P.Sub().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -375,8 +375,8 @@ def test_matmul_sub_broadcast2():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.sub = P.Sub().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.sub = P.Sub().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -399,8 +399,8 @@ def test_matmul_mul_broadcast():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.mul = P.Mul().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.mul = P.Mul().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -423,8 +423,8 @@ def test_matmul_mul_broadcast2():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.mul = P.Mul().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.mul = P.Mul().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -447,8 +447,8 @@ def test_matmul_div_broadcast():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.div = P.Div().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.div = P.Div().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -471,8 +471,8 @@ def test_matmul_div_broadcast2():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.div = P.Div().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.div = P.Div().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -495,8 +495,8 @@ def test_matmul_greater_broadcast():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.greater = P.Greater().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.greater = P.Greater().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -519,8 +519,8 @@ def test_matmul_greater_broadcast2():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.greater = P.Greater().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.greater = P.Greater().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -543,8 +543,8 @@ def test_matmul_floordiv():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.floordiv = P.FloorDiv().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.floordiv = P.FloorDiv().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -567,8 +567,8 @@ def test_matmul_floordiv_broadcast():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.floordiv = P.FloorDiv().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.floordiv = P.FloorDiv().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
@ -591,8 +591,8 @@ def test_matmul_floordiv_broadcast2():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.floordiv = P.FloorDiv().set_strategy(strategy2)
self.matmul = P.MatMul().shard(strategy1)
self.floordiv = P.FloorDiv().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)

@ -60,18 +60,18 @@ class Net(nn.Cell):
super().__init__()
self.query_w = Parameter(initializer(
"normal", [8, 16], ms.float32), name='query')
self.query = P.MatMul().set_strategy(strategy1)
self.query = P.MatMul().shard(strategy1)
self.key_w = Parameter(initializer(
"normal", [8, 16], ms.float32), name='key')
self.key = P.MatMul().set_strategy(strategy2)
self.key = P.MatMul().shard(strategy2)
self.value_w = Parameter(initializer(
"normal", [8, 16], ms.float32), name='value')
self.value = P.MatMul().set_strategy(strategy3)
self.value = P.MatMul().shard(strategy3)
self.score = P.MatMul().set_strategy(strategy4)
self.context = P.MatMul().set_strategy(strategy5)
self.score = P.MatMul().shard(strategy4)
self.context = P.MatMul().shard(strategy5)
self.transpose1 = P.Transpose()
self.transpose2 = P.Transpose()
self.relu = P.ReLU()

@ -24,8 +24,8 @@ from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None):
super().__init__()
self.mul = P.Mul().set_strategy(strategy1)
self.sigmoid = P.Sigmoid().set_strategy(strategy2)
self.mul = P.Mul().shard(strategy1)
self.sigmoid = P.Sigmoid().shard(strategy2)
self.mul_weight = Parameter(mul_weight, "w1")
def construct(self, x, b):

@ -107,7 +107,7 @@ def test_auto_parallel_arithmetic_model():
def __init__(self):
super().__init__()
self.matmul = P.MatMul()
self.one_hot = P.OneHot().set_strategy(((1, 8), (), ()))
self.one_hot = P.OneHot().shard(((1, 8), (), ()))
self.on_value = Tensor(1.0, ms.float32)
self.off_value = Tensor(0.0, ms.float32)
self.matmul2 = P.MatMul()

@ -53,7 +53,7 @@ def test_four_matmul_linear():
class Net(nn.Cell):
def __init__(self, strategy1):
super().__init__()
self.matmul1 = P.MatMul().set_strategy(strategy1)
self.matmul1 = P.MatMul().shard(strategy1)
self.matmul2 = P.MatMul()
self.matmul3 = P.MatMul()
self.matmul4 = P.MatMul()

@ -298,7 +298,7 @@ def test_reshape_auto_7():
def __init__(self):
super().__init__()
self.reshape = P.Reshape()
self.mul = P.Mul().set_strategy(((1, 2, 4), (2, 4)))
self.mul = P.Mul().shard(((1, 2, 4), (2, 4)))
self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
def construct(self, x):

@ -53,7 +53,7 @@ def test_four_matmul_linear():
class Net(nn.Cell):
def __init__(self, strategy1):
super().__init__()
self.matmul1 = P.MatMul().set_strategy(strategy1)
self.matmul1 = P.MatMul().shard(strategy1)
self.weight = Parameter(Tensor(np.ones([512, 256]).astype(np.float32) * 0.01), "w", requires_grad=True)
self.matmul2 = P.MatMul()

@ -24,8 +24,8 @@ from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, mul_weight, batch_matmul_weight, transpose_b=False, strategy1=None, strategy2=None):
super().__init__()
self.mul = P.Mul().set_strategy(strategy1)
self.batch_matmul = P.BatchMatMul(transpose_b=transpose_b).set_strategy(strategy2)
self.mul = P.Mul().shard(strategy1)
self.batch_matmul = P.BatchMatMul(transpose_b=transpose_b).shard(strategy2)
self.mul_weight = Parameter(mul_weight, "w1")
self.batch_matmul_weight = Parameter(batch_matmul_weight, "w2")

@ -73,7 +73,7 @@ class NetConv(nn.Cell):
has_bias,
weight_init,
bias_init)
self.conv.conv2d.set_strategy(strategy)
self.conv.conv2d.shard(strategy)
def construct(self, input_x):
return self.conv(input_x)
@ -84,9 +84,9 @@ def test_batch():
def __init__(self, strategy1, strategy2, strategy3):
super().__init__()
self.conv1 = NetConv(16, 8, (3, 3), bias_init='zeros', strategy=strategy1)
self.mul1 = P.Mul().set_strategy(strategy2)
self.mul1 = P.Mul().shard(strategy2)
self.conv2 = NetConv(8, 64, (9, 9), bias_init='zeros', strategy=strategy1)
self.mul2 = P.Mul().set_strategy(strategy3)
self.mul2 = P.Mul().shard(strategy3)
def construct(self, x, w1, w2):
out1 = self.conv1(x)

@ -64,7 +64,7 @@ def conv7x7(in_channels, out_channels, stride=1, padding=0):
conv = Conv2d(in_channels, out_channels,
kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False,
pad_mode="same")
conv.conv2d.set_strategy(strategy_weight)
conv.conv2d.shard(strategy_weight)
return conv
@ -86,7 +86,7 @@ def bn_with_initialize(out_channels):
gamma = weight_variable_1(shape)
bn = BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma,
beta_init=beta, moving_mean_init=mean, moving_var_init=var)
bn.bn_train.set_strategy(strategy_bn)
bn.bn_train.shard(strategy_bn)
return bn
@ -98,10 +98,10 @@ class ResNet(Cell):
self.conv1 = conv7x7(3, 64, stride=2, padding=0)
self.bn1 = bn_with_initialize(64)
self.relu = ReLU()
self.relu.relu.set_strategy(strategy_no_weight)
self.relu.relu.shard(strategy_no_weight)
self.maxpool = MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.reshape = P.Reshape()
self.matmul = P.MatMul().set_strategy(((8, 1), (1, 1)))
self.matmul = P.MatMul().shard(((8, 1), (1, 1)))
self.matmul_weight = Parameter(Tensor(np.ones([200704, num_classes]), dtype=ms.float32), name="weight")
def construct(self, x):
@ -135,7 +135,7 @@ def test_batchnorm_batch_parallel():
net = batchnorm_net(num_classes)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(((dev_num, 1), (dev_num, 1)))
loss.softmax_cross_entropy.shard(((dev_num, 1), (dev_num, 1)))
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
model = Model(net, loss, opt)

@ -51,13 +51,13 @@ def test_two_matmul_batchnorm_ex():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul1 = P.MatMul().set_strategy(strategy1)
self.matmul1 = P.MatMul().shard(strategy1)
self.norm = P.FusedBatchNormEx()
self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
self.matmul2 = P.MatMul().set_strategy(strategy2)
self.matmul2 = P.MatMul().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul1(x, y)

@ -70,7 +70,7 @@ class Net(nn.Cell):
super().__init__()
self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1, stride=1, pad_mode='valid',
has_bias=True, weight_init='ones', bias_init='ones')
self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(((1, 1, 1, 8),))
self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((1, 1, 1, 8),))
self.flat = nn.Flatten()
def construct(self, inputs):

@ -87,18 +87,18 @@ class FusedBatchNorm(nn.Cell):
epsilon=self.eps)
self.bn_infer = P.BatchNorm(is_training=False,
epsilon=self.eps)
self.sub_mean = P.Sub().set_strategy(((1), (1)))
self.sub_var = P.Sub().set_strategy(((1), (1)))
self.mul_mean = P.Mul().set_strategy(((1,), ()))
self.mul_var = P.Mul().set_strategy(((1,), ()))
self.assign_sub_mean = P.AssignSub().set_strategy(((1,), (1,)))
self.assign_sub_var = P.AssignSub().set_strategy(((1), (1)))
self.sub_mean2 = P.Sub().set_strategy(((1), (1)))
self.sub_var2 = P.Sub().set_strategy(((1), (1)))
def set_strategy(self, strategy):
self.bn_train.set_strategy(strategy)
self.bn_infer.set_strategy(strategy)
self.sub_mean = P.Sub().shard(((1), (1)))
self.sub_var = P.Sub().shard(((1), (1)))
self.mul_mean = P.Mul().shard(((1,), ()))
self.mul_var = P.Mul().shard(((1,), ()))
self.assign_sub_mean = P.AssignSub().shard(((1,), (1,)))
self.assign_sub_var = P.AssignSub().shard(((1), (1)))
self.sub_mean2 = P.Sub().shard(((1), (1)))
self.sub_var2 = P.Sub().shard(((1), (1)))
def shard(self, strategy):
self.bn_train.shard(strategy)
self.bn_infer.shard(strategy)
def _check_data_dim(self, x):
raise NotImplementedError
@ -173,7 +173,7 @@ class PReLU(nn.Cell):
w = Tensor(w)
self.w = Parameter(initializer(w, [channel,]), name='a')
self.prelu = P.PReLU()
self.relu = P.ReLU().set_strategy(((1)))
self.relu = P.ReLU().shard(((1)))
def construct(self, x):
self.w = self.relu(self.w)
@ -210,7 +210,7 @@ def bn_common(parallel_mode, train_flag, strategy_loss=None):
net = bn_net()
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(strategy_loss)
loss.softmax_cross_entropy.shard(strategy_loss)
opt = Momentum(net.trainable_params(), learning_rate, momentum, 0.0001, 1024 * rank_size)
if not train_flag:

@ -52,8 +52,8 @@ class CommonNet(nn.Cell):
def __init__(self):
super(CommonNet, self).__init__()
self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight")
self.logicalnot = P.LogicalNot().set_strategy(((4, 2),))
self.equal = P.Equal().set_strategy(((4, 2), (4, 2)))
self.logicalnot = P.LogicalNot().shard(((4, 2),))
self.equal = P.Equal().shard(((4, 2), (4, 2)))
def construct(self, x, label):
x = self.equal(x, self.weight)

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save