@@ -14,31 +14,28 @@
 # ============================================================================
 """ test_framstruct """
 import numpy as np
-import pytest
+
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter, ParameterTuple
-from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.utils.check_gradient import (
     ms_function, check_jacobian, Tensor, NNGradChecker,
-    OperationGradChecker, check_gradient, ScalarGradChecker)
+    OperationGradChecker, check_gradient)
 
 context.set_context(mode=context.PYNATIVE_MODE)
 
 
 def setup_module(module):
     context.set_context(mode=context.PYNATIVE_MODE)
 
 
-grad = C.GradOperation()
 grad_all = C.GradOperation(get_all=True)
 grad_by_list = C.GradOperation(get_by_list=True)
-grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
 
 
 @ms_function
@@ -79,9 +76,7 @@ def dynamic_make_tuple(x, lower, upper):
 
 
 def test_dynamic_make_tuple():
-    # Dynamicly recursively creating static type is invalid in mindspore, as mindspore is a static language.
-    with pytest.raises(RuntimeError):
-        dynamic_make_tuple(2, 1, 5)
+    assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)
 
 
 def test_make_tuple():
@@ -273,15 +268,6 @@ def rec(x):
         return rec(x - 1)
     return x
 
-@ms_function
-def grad_rec(input_x):
-    return grad(rec)(input_x)
-
-def test_grad_rec():
-    """ test_grad_rec """
-    res = grad_rec(3)
-    assert res == 1
-
 
 def test_me_rec():
     """ test_me_rec """
@@ -303,13 +289,6 @@ def test_while2():
     assert res == 6
 
 
-def test_grad_while2():
-    @ms_function
-    def df_t2_while(input_x, input_y):
-        return grad(t2_while)(input_x, input_y)
-    assert df_t2_while(2, 3) == 3
-
-
 def if_test(a, b):
     """ if_test """
     if a > b:
@@ -327,24 +306,6 @@ def test_grad_if():
     assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)
 
 
-# While loop is not unrolled in forward and backward graphs.
-def test_dont_unroll_while():
-    def dont_unroll_while(x, y):
-        i = 2
-        out = y - x
-        while i < 10:
-            out = mul(x, y)
-            i = i + 1
-        return out
-
-    @ms_function()
-    def invoke_while(x, y):
-        return grad(dont_unroll_while)(x, y)
-
-    res = invoke_while(2, 3)
-    assert res == 3
-
-
 class ConvNet(nn.Cell):
     def __init__(self):
         super(ConvNet, self).__init__()
@@ -445,13 +406,6 @@ def test_factorial():
     assert res == 6
 
 
-def test_grad_factorial():
-    @ms_function
-    def df_factorial(x):
-        return grad(factorial)(x)
-    assert df_factorial(3) == 11
-
-
 @ms_function
 def factorial2(n):
     """ factorial """
@@ -523,17 +477,13 @@ def _for(x):
         ret = ret * i
     return ret
 
 
 @ms_function
 def grad_for(x):
     """ grad_for """
     return grad_all(_for)(x)
 
 
-def test_grad_for():
-    """ test_grad_for """
-    assert grad_for(5) == (60,)
-
 @ms_function
 def try_tail(x):
     """ try_tail """
@@ -675,15 +625,6 @@ def test_arithmetic_simplify_08():
     assert np.all(res.asnumpy() == expect)
 
 
-def test_ScalarGradChecker():
-    """ test_ScalarGradChecker """
-
-    def scalar_f(x, y):
-        return x * y
-
-    check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1)
-
-
 def test_GradCheckerPrimitive():
     """ test_GradCheckerPrimitive """
     matmul = P.MatMul()
@@ -737,15 +678,6 @@ def test_OperationGradChecker():
                    input_selector=[1], sampling_times=2)
 
 
-def test_ScalarJacobianChecker():
-    """ test_ScalarJacobianChecker """
-
-    def scalar_f(x, y):
-        return x * y
-
-    check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0])
-
-
 def test_OperationJacobianChecker():
     """ test_OperationJacobianChecker """
@@ -795,13 +727,6 @@ def multi_outputs(x, y):
     return 2 * z, 2 * z
 
 
-def test_grad_multi_outputs():
-    @ms_function
-    def df_multi_outputs(x, y):
-        return grad_all_with_sens(multi_outputs)(x, y, (1, 1))
-    assert df_multi_outputs(2, 3) == (4, 4)
-
-
 @ms_function
 def while_sp(x, y, z):
     out = x
@@ -874,13 +799,6 @@ def grad_refactor_3(a):
     return 3 * a
 
 
-def test_grad_refactor_3():
-    @ms_function
-    def df_refactor_3(x):
-        return grad_all(grad_refactor_3)(x)
-    assert df_refactor_3(3) == (3,)
-
-
 def grad_refactor_4(a):
     """ if_test """
     if a > 3:
@@ -899,13 +817,6 @@ def grad_refactor_5(a):
     return a
 
 
-def test_grad_refactor_5():
-    @ms_function
-    def df_refactor_5(x):
-        return grad_all(grad_refactor_5)(x)
-    assert df_refactor_5(1) == (1,)
-
-
 def grad_refactor_6(a, b):
     """ if_test """
     if a > b:
@@ -925,13 +836,6 @@ def grad_refactor_while(x):
     return rval
 
 
-def test_grad_refactor_9():
-    @ms_function
-    def df_refactor_while(input_x):
-        return grad_all(grad_refactor_while)(input_x)
-    assert df_refactor_while(3) == (6,)
-
-
 def grad_refactor__while_1(x):
     """ _while """
     ret = x * x
@@ -1009,13 +913,6 @@ def grad_refactor_14(a, b):
     return inner1(b) + inner2(a) + inner3(a)
 
 
-def test_grad_refactor_14():
-    @ms_function
-    def df_refactor_14(x, y):
-        return grad_all(grad_refactor_14)(x, y)
-    assert df_refactor_14(2, 3) == (3, 9)
-
-
 # pylint: disable=using-constant-test
 class IfDeferInline(nn.Cell):
     def __init__(self, mul_size):
@@ -1044,6 +941,8 @@ def test_dict_const():
         def __init__(self):
             super(Net, self).__init__()
             self.res = {'1': 10}
+
         def construct(self):
            return self.res
+
     Net()()
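
Note for readers outside the patch context: the tests deleted above all share one calling convention, a C.GradOperation helper wrapped around a small Python function and invoked in PYNATIVE_MODE, with a plain assert on the returned gradients. The sketch below only documents that convention; the function and test names are hypothetical and are not taken from the patch, and since this patch deletes tests of exactly this form, treat it as an illustration of the pattern rather than a guaranteed-passing test.

# Illustrative sketch only -- hypothetical names, not part of the patch above.
from mindspore import context
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

# get_all=True makes the GradOperation return the gradient with respect to
# every positional input, as a tuple.
grad_all = C.GradOperation(get_all=True)


def square_sum(x, y):
    # d/dx = 2*x, d/dy = 2*y
    return x * x + y * y


def test_grad_all_square_sum():
    # Same shape as the removed tests: apply the GradOperation to the function,
    # call the result with the original inputs, assert on the gradient tuple.
    assert grad_all(square_sum)(2, 3) == (4, 6)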