|
|
|
@ -45,11 +45,11 @@ class GradWrap(nn.Cell):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
|
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
|
|
|
|
|
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
|
|
|
|
|
super().__init__()
|
|
|
|
|
if shape is None:
|
|
|
|
|
shape = [64, 64]
|
|
|
|
|
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
|
|
|
|
|
self.gatherv2 = P.GatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
|
|
|
|
|
self.mul = P.Mul().set_strategy(strategy2)
|
|
|
|
|
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
|
|
|
|
self.axis = axis
|
|
|
|
@ -188,7 +188,7 @@ def test_gatherv2_cpu0():
|
|
|
|
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
|
|
|
|
strategy1 = ((8, 1), (1, 1))
|
|
|
|
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
|
|
|
|
net = NetWithLoss(Net(0, strategy1, strategy2))
|
|
|
|
|
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
|
|
|
|
net.set_auto_parallel()
|
|
|
|
|
|
|
|
|
|
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
|
|
|
@ -200,7 +200,7 @@ def test_gatherv2_cpu1():
|
|
|
|
|
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
|
|
|
|
|
strategy1 = ((16, 1), (1, 1))
|
|
|
|
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
|
|
|
|
net = NetWithLoss(Net(0, strategy1, strategy2))
|
|
|
|
|
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
|
|
|
|
net.set_auto_parallel()
|
|
|
|
|
|
|
|
|
|
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
|
|
|
@ -212,7 +212,7 @@ def test_gatherv2_cpu2():
|
|
|
|
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
|
|
|
|
strategy1 = ((1, 8), (1, 1))
|
|
|
|
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
|
|
|
|
net = NetWithLoss(Net(0, strategy1, strategy2))
|
|
|
|
|
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
|
|
|
|
net.set_auto_parallel()
|
|
|
|
|
|
|
|
|
|
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
|
|
|
|