|
|
@ -15,23 +15,17 @@
|
|
|
|
""" test_framstruct """
|
|
|
|
""" test_framstruct """
|
|
|
|
import pytest
|
|
|
|
import pytest
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import mindspore as ms
|
|
|
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
import mindspore.nn as nn
|
|
|
|
from mindspore import context
|
|
|
|
from mindspore import context
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
|
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
|
|
from mindspore.common.initializer import initializer
|
|
|
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
|
|
|
from mindspore.nn.wrap.cell_wrapper import WithGradCell, WithLossCell
|
|
|
|
|
|
|
|
from ..ut_filter import non_graph_engine
|
|
|
|
from ..ut_filter import non_graph_engine
|
|
|
|
from ....mindspore_test_framework.utils.check_gradient import (
|
|
|
|
from ....mindspore_test_framework.utils.check_gradient import (
|
|
|
|
ms_function, check_jacobian, Tensor, NNGradChecker,
|
|
|
|
ms_function, check_jacobian, Tensor, NNGradChecker,
|
|
|
|
OperationGradChecker, check_gradient, ScalarGradChecker)
|
|
|
|
OperationGradChecker, check_gradient, ScalarGradChecker)
|
|
|
|
from ....mindspore_test_framework.utils.bprop_util import bprop
|
|
|
|
|
|
|
|
import mindspore.context as context
|
|
|
|
|
|
|
|
from mindspore.ops._grad.grad_base import bprop_getters
|
|
|
|
from mindspore.ops._grad.grad_base import bprop_getters
|
|
|
|
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
|
|
|
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
|
|
|
|
|
|
|
|
|
|
@ -299,22 +293,22 @@ def test_dont_unroll_while():
|
|
|
|
assert res == 3
|
|
|
|
assert res == 3
|
|
|
|
|
|
|
|
|
|
|
|
class ConvNet(nn.Cell):
|
|
|
|
class ConvNet(nn.Cell):
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
|
super(ConvNet, self).__init__()
|
|
|
|
super(ConvNet, self).__init__()
|
|
|
|
out_channel = 16
|
|
|
|
out_channel = 16
|
|
|
|
kernel_size = 3
|
|
|
|
kernel_size = 3
|
|
|
|
self.conv = P.Conv2D(out_channel,
|
|
|
|
self.conv = P.Conv2D(out_channel,
|
|
|
|
kernel_size,
|
|
|
|
kernel_size,
|
|
|
|
mode=1,
|
|
|
|
mode=1,
|
|
|
|
pad_mode="pad",
|
|
|
|
pad_mode="pad",
|
|
|
|
pad=0,
|
|
|
|
pad=0,
|
|
|
|
stride=1,
|
|
|
|
stride=1,
|
|
|
|
dilation=2,
|
|
|
|
dilation=2,
|
|
|
|
group=1)
|
|
|
|
group=1)
|
|
|
|
self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')
|
|
|
|
self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')
|
|
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
def construct(self, x):
|
|
|
|
return self.conv(x, self.w)
|
|
|
|
return self.conv(x, self.w)
|
|
|
|
|
|
|
|
|
|
|
|
conv = ConvNet()
|
|
|
|
conv = ConvNet()
|
|
|
|
c1 = Tensor([2], mstype.float32)
|
|
|
|
c1 = Tensor([2], mstype.float32)
|
|
|
@ -674,7 +668,7 @@ def grad_refactor_6(a, b):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_grad_refactor_6():
|
|
|
|
def test_grad_refactor_6():
|
|
|
|
C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
|
|
|
|
assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def grad_refactor_while(x):
|
|
|
|
def grad_refactor_while(x):
|
|
|
|