diff --git a/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py b/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py index ee91d56097..5c5dd090e1 100644 --- a/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py +++ b/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py @@ -51,13 +51,13 @@ def test_two_matmul_batchnorm_ex(): class Net(nn.Cell): def __init__(self, strategy1, strategy2): super().__init__() - self.matmul1 = P.MatMul().shard(strategy1) + self.matmul1 = P.BatchMatMul().shard(strategy1) self.norm = P.FusedBatchNormEx() self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma") self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta") self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean") self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var") - self.matmul2 = P.MatMul().shard(strategy2) + self.matmul2 = P.BatchMatMul().shard(strategy2) def construct(self, x, y, b): out = self.matmul1(x, y) @@ -66,12 +66,12 @@ def test_two_matmul_batchnorm_ex(): return out context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8) - strategy1 = ((4, 2), (2, 1)) - strategy2 = ((1, 8), (8, 1)) + strategy1 = ((1, 1, 4, 2), (1, 1, 2, 1)) + strategy2 = ((1, 1, 1, 8), (1, 1, 8, 1)) net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) net.set_auto_parallel() - x = Tensor(np.ones([128, 32]), dtype=ms.float32) - y = Tensor(np.ones([32, 64]), dtype=ms.float32) - b = Tensor(np.ones([64, 64]), dtype=ms.float32) + x = Tensor(np.ones([64, 64, 128, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 32, 64]), dtype=ms.float32) + b = Tensor(np.ones([64, 64, 64, 64]), dtype=ms.float32) net.set_train() _executor.compile(net, x, y, b)