|
|
|
@ -272,3 +272,32 @@ def test_cast_before_mirror3():
|
|
|
|
|
y = Tensor(np.ones([32, 64]), dtype=ms.float16)
|
|
|
|
|
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
|
|
|
|
_executor.compile(net, x, y, b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_mul_two_cast():
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
|
def __init__(self, strategy1, strategy2, strategy3):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.mul = P.Mul().set_strategy(strategy1)
|
|
|
|
|
self.mul2 = P.Mul().set_strategy(strategy2)
|
|
|
|
|
self.cast = P.Cast().set_strategy(strategy3)
|
|
|
|
|
self.cast2 = P.Cast().set_strategy(strategy3)
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y, b):
|
|
|
|
|
out = self.mul(x, y)
|
|
|
|
|
out = self.mul2(out, b)
|
|
|
|
|
out = self.cast(out, ms.int32)
|
|
|
|
|
out = self.cast2(out, ms.bool_)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
|
|
|
|
strategy1 = ((2, 2), (2, 2))
|
|
|
|
|
strategy2 = ((8, 1), (8, 1))
|
|
|
|
|
strategy3 = ((8, 1), )
|
|
|
|
|
net = GradWrap(Net(strategy1, strategy2, strategy3))
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
|
|
|
|
|
|
|
|
|
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
|
|
|
|
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
|
|
|
|
b = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
|
|
|
|
_executor.compile(net, x, y, b)
|
|
|
|
|