Change the input dimension for FusedBatchNormEx from 2D to 4D in test_two_matmul_batchnorm_ex.

pull/12608/head
wangshuide2020 4 years ago
parent 5b206557d2
commit 72e938eb06

@@ -51,13 +51,13 @@ def test_two_matmul_batchnorm_ex():
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2):
             super().__init__()
-            self.matmul1 = P.MatMul().shard(strategy1)
+            self.matmul1 = P.BatchMatMul().shard(strategy1)
             self.norm = P.FusedBatchNormEx()
             self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
             self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
             self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
             self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
-            self.matmul2 = P.MatMul().shard(strategy2)
+            self.matmul2 = P.BatchMatMul().shard(strategy2)
 
         def construct(self, x, y, b):
             out = self.matmul1(x, y)
@@ -66,12 +66,12 @@ def test_two_matmul_batchnorm_ex():
             return out
 
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8)
-    strategy1 = ((4, 2), (2, 1))
-    strategy2 = ((1, 8), (8, 1))
+    strategy1 = ((1, 1, 4, 2), (1, 1, 2, 1))
+    strategy2 = ((1, 1, 1, 8), (1, 1, 8, 1))
     net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
     net.set_auto_parallel()
-    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    x = Tensor(np.ones([64, 64, 128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 32, 64]), dtype=ms.float32)
+    b = Tensor(np.ones([64, 64, 64, 64]), dtype=ms.float32)
     net.set_train()
     _executor.compile(net, x, y, b)
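For reference, a minimal sketch of the Net class as it reads after this commit. The middle of construct falls between the two hunks and is not visible in the diff; the norm and matmul2 calls below are an assumed reconstruction (FusedBatchNormEx returns a tuple whose first element is the normalized tensor). NetWithLoss, GradWrap and _executor are helpers defined earlier in the test file, and actually executing FusedBatchNormEx requires a backend that supports it.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class Net(nn.Cell):
    """Two sharded BatchMatMuls with a FusedBatchNormEx between them."""
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.matmul1 = P.BatchMatMul().shard(strategy1)
        self.norm = P.FusedBatchNormEx()
        # FusedBatchNormEx normalizes over the channel axis of a 4D NCHW
        # input, so the affine/statistics parameters match the channel
        # size of 64.
        self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
        self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
        self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
        self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
        self.matmul2 = P.BatchMatMul().shard(strategy2)

    def construct(self, x, y, b):
        # [64, 64, 128, 32] @ [64, 64, 32, 64] -> [64, 64, 128, 64]
        out = self.matmul1(x, y)
        # Assumed middle of construct (not shown in the diff hunks):
        # element 0 of the FusedBatchNormEx output tuple is the result.
        out = self.norm(out, self.gamma, self.beta, self.mean, self.var)[0]
        # [64, 64, 128, 64] @ [64, 64, 64, 64] -> [64, 64, 128, 64]
        out = self.matmul2(out, b)
        return out

The new 4D shard strategies mirror the old 2D ones with the two leading batch dimensions left unsplit: strategy1 = ((1, 1, 4, 2), (1, 1, 2, 1)) splits the 128 rows of x across 4 devices and the contracted dimension of 32 across 2, which matches device_num=8.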
