@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ grad_by_list = C.GradOperation(get_by_list=True)
 grad_all = C.GradOperation(get_all=True)
 
 
-def test_while_forward():
+def test_while_grad():
     class MyWhileNet(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -46,31 +46,71 @@ def test_while_forward():
                 x[idx, :, 0:2] = max_num
                 idx = idx + 1
             return x
 
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
     # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    net = MyWhileNet()
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(2), dtype=ms.int32)
-    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
     graph_output = net(idx, end, x)
     # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
     pynative_output = net(idx, end, x)
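+    # grad_all was built with get_all=True, so it differentiates w.r.t. every
+    # network input: the output is a (d_idx, d_end, d_x) tuple, checked element-wise.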
-    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
+
+
+def test_while_with_const_param_grad():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.mul = P.Mul()
+            self.add = P.Add()
 
-
-def test_while_grad():
+        def construct(self, x, y):
+            while x < y:
+                z = self.mul(x, x)
+                x = self.add(z, 1)
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor([1.1], dtype=ms.float32)
+    end = Tensor([8.0], dtype=ms.float32)
+    graph_output = net(idx, end)
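+    # Three iterations run: x goes 1.1 -> 2.21 -> 5.8841 -> 35.62 (>= 8.0, stop).
+    # d(x_out)/d(x_in) is the product of 2*x per step: 2.2 * 4.42 * 11.7682 ~= 114.434.
+    # y only gates the loop condition, so its gradient is 0.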
+    expect_one = np.array([1.14433983e+02], dtype=np.float32)
+    expect_two = np.array([0], dtype=np.float32)
+    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
+
+
+def test_while_with_variable_grad():
     class MyWhileNet(nn.Cell):
         def __init__(self):
             super().__init__()
-            self.max = P.ReduceMax()
+            self.mul = P.Mul()
+            self.add = P.Add()
 
-        def construct(self, idx, end, x):
-            while idx < end:
-                part = x[idx, :, :]
-                max_num = self.max(part)
-                x[idx, :, 0:2] = max_num
-                idx = idx + 1
+        def construct(self, x, y):
+            while x < y:
+                z = self.mul(x, x)
+                x = self.add(z, y)
             return x
 
     class GradNet(nn.Cell):
@@ -80,20 +120,16 @@ def test_while_grad():
         def construct(self, *inputs):
             return grad_all(self.net)(*inputs)
 
-    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
-    idx = Tensor(np.array(0), dtype=ms.int32)
-    end = Tensor(np.array(2), dtype=ms.int32)
-    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    graph_output = net(idx, end, x)
-    # pynative mode
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    pynative_output = net(idx, end, x)
-    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
-    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
-    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
+    idx = Tensor([1.1], dtype=ms.float32)
+    end = Tensor([8.0], dtype=ms.float32)
+    graph_output = net(idx, end)
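+    # x starts below y, so exactly one iteration runs: x <- 1.1**2 + 8.0 = 9.21.
+    # Hence d(x_out)/d(x_in) = 2 * 1.1 = 2.2 and d(x_out)/dy = 1.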
+    expect_one = np.array([2.20000005e+00], dtype=np.float32)
+    expect_two = np.array([1.00000000e+00], dtype=np.float32)
+    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
 
 
 def test_while_with_param_forward():
     class MyWhileNet(nn.Cell):
@@ -153,7 +189,6 @@ def test_while_endless_case():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
 
 def test_while_with_param_grad():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -180,7 +215,6 @@ def test_while_with_param_grad():
 
         def construct(self, a, b, c):
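+            # grad_by_list was built with get_by_list=True: it differentiates
+            # w.r.t. the parameters in self.weights, one gradient per parameter.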
             return grad_by_list(self.net, self.weights)(a, b, c)
 
-    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
@@ -188,10 +222,8 @@ def test_while_with_param_grad():
     end = Tensor(np.array(2), dtype=ms.int32)
     x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
     graph_output = net(idx, end, x)
-    # pynative mode
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    pynative_output = net(idx, end, x)
-    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
+    expect = np.array([[[2, 2], [2, 2]], [[2, 2], [2, 2]]], dtype=np.int32)
+    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)
 
 
 def test_while_with_param_forward_with_const_branch():
     class MyWhileNet(nn.Cell):