|
|
@ -15,23 +15,17 @@
|
|
|
|
""" test_framstruct """
|
|
|
|
""" test_framstruct """
|
|
|
|
import pytest
|
|
|
|
import pytest
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import mindspore as ms
|
|
|
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
import mindspore.nn as nn
|
|
|
|
from mindspore import context
|
|
|
|
from mindspore import context
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
|
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
|
|
from mindspore.common.initializer import initializer
|
|
|
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
|
|
|
from mindspore.nn.wrap.cell_wrapper import WithGradCell, WithLossCell
|
|
|
|
|
|
|
|
from ..ut_filter import non_graph_engine
|
|
|
|
from ..ut_filter import non_graph_engine
|
|
|
|
from ....mindspore_test_framework.utils.check_gradient import (
|
|
|
|
from ....mindspore_test_framework.utils.check_gradient import (
|
|
|
|
ms_function, check_jacobian, Tensor, NNGradChecker,
|
|
|
|
ms_function, check_jacobian, Tensor, NNGradChecker,
|
|
|
|
OperationGradChecker, check_gradient, ScalarGradChecker)
|
|
|
|
OperationGradChecker, check_gradient, ScalarGradChecker)
|
|
|
|
from ....mindspore_test_framework.utils.bprop_util import bprop
|
|
|
|
|
|
|
|
import mindspore.context as context
|
|
|
|
|
|
|
|
from mindspore.ops._grad.grad_base import bprop_getters
|
|
|
|
from mindspore.ops._grad.grad_base import bprop_getters
|
|
|
|
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
|
|
|
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
|
|
|
|
|
|
|
|
|
|
@ -674,7 +668,7 @@ def grad_refactor_6(a, b):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_grad_refactor_6():
|
|
|
|
def test_grad_refactor_6():
|
|
|
|
C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
|
|
|
|
assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def grad_refactor_while(x):
|
|
|
|
def grad_refactor_while(x):
|
|
|
|