|
|
|
@ -51,13 +51,13 @@ def test_two_matmul_batchnorm_ex():
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
|
def __init__(self, strategy1, strategy2):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.matmul1 = P.MatMul().shard(strategy1)
|
|
|
|
|
self.matmul1 = P.BatchMatMul().shard(strategy1)
|
|
|
|
|
self.norm = P.FusedBatchNormEx()
|
|
|
|
|
self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
|
|
|
|
|
self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
|
|
|
|
|
self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
|
|
|
|
|
self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
|
|
|
|
|
self.matmul2 = P.MatMul().shard(strategy2)
|
|
|
|
|
self.matmul2 = P.BatchMatMul().shard(strategy2)
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y, b):
|
|
|
|
|
out = self.matmul1(x, y)
|
|
|
|
@ -66,12 +66,12 @@ def test_two_matmul_batchnorm_ex():
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8)
|
|
|
|
|
strategy1 = ((4, 2), (2, 1))
|
|
|
|
|
strategy2 = ((1, 8), (8, 1))
|
|
|
|
|
strategy1 = ((1, 1, 4, 2), (1, 1, 2, 1))
|
|
|
|
|
strategy2 = ((1, 1, 1, 8), (1, 1, 8, 1))
|
|
|
|
|
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
|
|
|
|
net.set_auto_parallel()
|
|
|
|
|
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
|
|
|
|
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
|
|
|
|
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
|
|
|
|
x = Tensor(np.ones([64, 64, 128, 32]), dtype=ms.float32)
|
|
|
|
|
y = Tensor(np.ones([64, 64, 32, 64]), dtype=ms.float32)
|
|
|
|
|
b = Tensor(np.ones([64, 64, 64, 64]), dtype=ms.float32)
|
|
|
|
|
net.set_train()
|
|
|
|
|
_executor.compile(net, x, y, b)
|
|
|
|
|