# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_framstruct: control-flow, tuple/list and gradient checks on ms_function graphs. """
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore.nn.wrap.cell_wrapper import WithGradCell, WithLossCell
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.check_gradient import (
    ms_function, check_jacobian, Tensor, NNGradChecker,
    OperationGradChecker, check_gradient, ScalarGradChecker)
from ....mindspore_test_framework.utils.bprop_util import bprop
import mindspore.context as context
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer


def setup_module(module):
    """Module-level setup: switch the MindSpore context to PYNATIVE_MODE."""
    context.set_context(mode=context.PYNATIVE_MODE)


@ms_function
def while_upper_bound(upper):
    """Square 2 repeatedly while the running value is below `upper`."""
    rval = 2
    while rval < upper:
        rval = rval * rval
    return rval


def test_while_upper_bound():
    """2 -> 4 -> 16: 16 is the first value that is not below 10."""
    res = while_upper_bound(10)
    assert res == 16


@ms_function
def while_lower_bound(lower):
    """Square `lower` repeatedly while the running value is below 100."""
    rval = lower
    while rval < 100:
        rval = rval * rval
    return rval


def test_while_lower_bound():
    """2 -> 4 -> 16 -> 256: 256 is the first value that is not below 100."""
    res = while_lower_bound(2)
    assert res == 256


@ms_function
def dynamic_make_tuple(x, lower, upper):
    """Grow a tuple by one `x` per iteration of a while loop over [lower, upper)."""
    out = ()
    i = lower
    while i < upper:
        out = out + (x,)
        i = i + 1
    return out


def test_dynamic_make_tuple():
    """Growing a tuple inside a while loop is expected to raise RuntimeError."""
    # Dynamically recursively creating static type is invalid in mindspore,
    # as mindspore is a static language.
    with pytest.raises(RuntimeError):
        dynamic_make_tuple(2, 1, 5)


def test_make_tuple():
    """Growing a tuple inside a `for ... in range(3)` loop yields three copies of x."""
    # Statically recursively creating static type is valid in mindspore.
    @ms_function
    def make_tuple(x):
        out = ()
        for i in range(3):
            out = out + (x,)
        return out

    res = make_tuple(5)
    assert res == (5, 5, 5)


@ms_function
def add(x, y):
    """Return x + y."""
    return x + y


def mul(x, y):
    """Return x * y."""
    return x * y


def add_mul(x, y):
    """Return (x + y) * y."""
    return (x + y) * y


def mainf(x, y):
    """Return the result of C.grad_all(mul) applied to (x, y)."""
    return C.grad_all(mul)(x, y)


def grad_add_mul(x, y):
    """Return the result of C.grad_all(add_mul) applied to (x, y)."""
    return C.grad_all(add_mul)(x, y)


@ms_function
def sub(x, y):
    """Return x - y."""
    return x - y


@ms_function
def if_always_true(x):
    """Return x through a branch whose condition is the literal True."""
    if True:
        return x
    else:
        return 0


def test_add():
    """add(2.5, 3) mixes a float and an int and must give 5.5."""
    res = add(2.5, 3)
    assert res == 5.5


def test_sub():
    """sub(3.5, 3) mixes a float and an int and must give 0.5."""
    res = sub(3.5, 3)
    assert res == 0.5


@non_graph_engine
def test_if_always_true():
    """The constant-true branch must return the input unchanged."""
    res = if_always_true(1)
    assert res == 1


@non_graph_engine
def test_f():
    """Gradients of x * y at (3, 2) are (y, x) = (2, 3)."""
    res = mainf(3, 2)
    assert res == (2, 3)


@non_graph_engine
def test_grad_add_mul():
    """Gradients of (x + y) * y at (3, 2) are (y, x + 2 * y) = (2, 7)."""
    res = grad_add_mul(3, 2)
    assert res == (2, 7)


def f(x):
    """Plain-Python recursion that counts x down to 0; not referenced in this part of the file."""
    if x > 0:
        return f(x-1)
    return x


@ms_function
def list_subscript():
    """Build a list literal and return the product of its first two elements."""
    x = [1, 2, 3]
    return x[0] * x[1]


def test_list_subscript():
    """1 * 2 == 2."""
    res = list_subscript()
    assert res == 2


@ms_function
def ms_infer_for(xs, y):
    """Return y plus the sum of the elements of xs, accumulated with a for loop."""
    rval = y
    for x in xs:
        rval = rval + x
    return rval


def test_infer_for():
    """4 + 1 + 2 + 3 == 10."""
    t = (1, 2, 3)
    y = 4
    res = ms_infer_for(t, y)
    assert res == 10


@ms_function
def if_construct(a, b):
    """Two consecutive if/else statements; the second one returns from both arms."""
    z = a
    if a > b:
        z = a+b
    else:
        z = a*b
    if z > b:
        return z-a
    else:
        return a-b


def test_if_construct():
    """With (3, 6): z = 3 * 6 = 18, 18 > 6, so the result is 18 - 3 = 15."""
    res = if_construct(3, 6)
    assert res == 15


@ms_function
def if_scalar(a, b):
    """Return a when a is truthy, otherwise b."""
    if a:
        return a
    return b


def test_if_scalar1():
    """A non-zero first argument is returned."""
    res = if_scalar(3, 6)
    assert res == 3


def test_if_scalar2():
    """A zero first argument falls through to the second argument."""
    res = if_scalar(0, 6)
    assert res == 6


@ms_function
def if_tensor(a, b):
    """Nested if/else whose conditions compare the two inputs with `<`."""
    c = a
    if a < b:
        c = a+a
        if c < b:
            c = a+c
        else:
            c = a+b
    else:
        c = b+b
    out = c + c
    return out


def test_if_tensor():
    """With two all-ones tensors the expected result is all fours (else arm: (b + b) * 2)."""
    res = if_tensor(Tensor(np.ones([64, 10]).astype(np.int32)),
                    Tensor(np.ones([64, 10]).astype(np.int32)))
    assert res == Tensor(np.ones([64, 10]).astype(np.int32) * 4)


@ms_function
def rec(x):
    """Recurse with x - 1 while x > 0, then return the final x."""
    if x > 0:
        return rec(x-1)
    return x


def test_grad_rec():
    """Each recursion step only subtracts a constant, so the expected gradient is 1."""
    res = C.grad(rec)(10)
    assert res == 1


def test_me_rec():
    """Counting down from 10 ends at 0."""
    res = rec(10)
    assert res == 0


@ms_function
def t2_while(x, y):
    """Overwrite `out` with mul(x, y) on each of ten while-loop iterations."""
    out = y - x
    i = 0
    while i < 10:
        out = mul(x, y)
        i = i + 1
    return out


def test_while2():
    """The loop body runs, so the result is 2 * 3 = 6 rather than y - x."""
    res = t2_while(2, 3)
    assert res == 6


def test_grad_while2():
    """C.grad of the x * y result at (2, 3) is expected to be 3, the value of y."""
    res = C.grad(t2_while)(2, 3)
    assert res == 3


def if_test(a, b):
    """Return 3 * a when a > b, otherwise 2 * b."""
    if a > b:
        return 3 * a
    return 2 * b


def grad_if(x, y):
    """Return the result of C.grad_all(if_test) applied to (x, y)."""
    return C.grad_all(if_test)(x, y)


def test_grad_if():
    """At (5, 4) the 3 * a branch is taken, so the gradients are (3, 0)."""
    assert grad_if(5, 4) == (3, 0)


# While loop is not unrolled in forward and backward graphs.
def test_dont_unroll_while():
    """Take C.grad of a while-loop function from inside an ms_function graph."""
    def dont_unroll_while(x, y):
        i = 2
        out = y - x
        while i < 10:
            out = mul(x, y)
            i = i + 1
        return out

    @ms_function()
    def invoke_while(x, y):
        return C.grad(dont_unroll_while)(x, y)

    # The loop leaves out = x * y, so the expected gradient at (2, 3) is y = 3.
    res = invoke_while(2, 3)
    assert res == 3


class ConvNet(nn.Cell):
    """Cell that applies one P.Conv2D with an all-ones [16, 16, 3, 3] weight."""

    def __init__(self):
        super(ConvNet, self).__init__()
        out_channel = 16
        kernel_size = 3
        self.conv = P.Conv2D(out_channel,
                             kernel_size,
                             mode=1,
                             pad_mode="pad",
                             pad=0,
                             stride=1,
                             dilation=2,
                             group=1)
        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

    def construct(self, x):
        return self.conv(x, self.w)


# Shared network instance and tensor constants used as loop start, bound and
# step by the tensor-controlled while loops below.
conv = ConvNet()
c1 = Tensor([2], mstype.float32)
c2 = Tensor([10], mstype.float32)
c3 = Tensor([1], mstype.float32)


@ms_function
def t1_while(x, y, z):
    """Add conv(z) to x while the tensor counter runs from c1 up to c2, then double the sum."""
    out = x
    i = c1
    while i < c2:
        out = out + conv(z)
        i = i + c3
    out = out + out
    return out


def test_while_net():
    """All-ones inputs are expected to give a constant 2306.0 tensor of shape [1, 16, 12, 12]."""
    y = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
    x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
    z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
    res = t1_while(x, y, z)
    assert res == Tensor(np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)


@ms_function
def if_while(a, b, x, z):
    """While loop nested in the true arm of an if; the returned value is c + c."""
    c = a
    i = c1
    out = x
    if a < b:
        c = a+a
        while i < c2:
            out = out + conv(z)
            i = i + c3
    else:
        c = b+b
    # `out` is rebound here, so whatever the loop accumulated is discarded.
    out = c + c
    return out


def test_if_while():
    """With all-ones a and b the expected result is all 4.0, whatever x and z hold."""
    x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
    z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
    res = if_while(Tensor(np.ones([64, 10]).astype(np.float32)),
                   Tensor(np.ones([64, 10]).astype(np.float32)), x, z)
    assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)


def _while(x):
    """Return x * x multiplied by 2 and then 3 inside a while loop."""
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def grad_while(x):
    """Return the result of C.grad_all(_while) applied to x."""
    return C.grad_all(_while)(x)


def test_grad_while():
    """_while(x) is 6 * x * x, so the gradient at 5 is 12 * 5 = 60."""
    assert grad_while(5) == (60,)


@ms_function
def factorial(n):
    """Recursive factorial with base case n == 0."""
    if n == 0:
        return 1
    return n * factorial(n-1)


def test_factorial():
    """3! == 6."""
    res = factorial(3)
    assert res == 6


def test_grad_factorial():
    """Differentiating through the recursion at n = 3 is expected to give 11."""
    res = C.grad(factorial)(3)
    assert res == 11


@ms_function
def factorial2(n):
    """Recursive factorial written as an if/elif/else chain."""
    if n != 0:
        return n * factorial2(n-1)
    # n is 0 whenever control gets here, so this arm cannot be taken.
    elif n == 1:
        return 1 * factorial2(n-1)
    else:
        return 1


def test_factorial2():
    """3! == 6."""
    res = factorial2(3)
    assert res == 6


@ms_function
def foo(n):
    """Recurse with n - 1 from both the outer and the nested if; return 1 once n < 1."""
    if n <= 1:
        if n == 1:
            return foo(n-1)
        else:
            return 1
    else:
        return foo(n-1)


def test_foo():
    """Every path ends in the constant 1."""
    res = foo(5)
    assert res == 1


@ms_function
def double_nested_loop(x):
    """Two nested while loops: each outer pass adds 1 + 2 + 3 to s."""
    i = 0
    s = 0
    while(i < x):
        j = 0
        i = i + 1
        while(j < 3):
            j = j + 1
            s = s + j
    return s


def test_nested_loop():
    """Three outer passes of 6 each give 18."""
    res = double_nested_loop(3)
    assert res == 18


@ms_function
def double_nested_loop2(x):
    """Two nested for loops over range(x) and range(3), summing the inner index."""
    s = 0
    for i in range(x):
        for j in range(3):
            s = s + j
    return s


def test_nested_loop2():
    """One outer pass of double_nested_loop gives 1 + 2 + 3 = 6."""
    # NOTE(review): this calls double_nested_loop, not double_nested_loop2, so
    # the for-loop variant is never exercised. double_nested_loop2(1) would sum
    # 0 + 1 + 2 = 3; confirm which callee and expected value are intended.
    res = double_nested_loop(1)
    assert res == 6


def _for(x):
    """Return x * x multiplied by each element of the tuple (2, 3)."""
    ret = x * x
    for i in (2, 3):
        ret = ret * i
    return ret


def grad_for(x):
    """Return the result of C.grad_all(_for) applied to x."""
    return C.grad_all(_for)(x)


def test_grad_for():
    """_for(x) is 6 * x * x, so the gradient at 5 is 12 * 5 = 60."""
    assert grad_for(5) == (60,)


@ms_function
def try_tail(x):
    """Return C.tail(x)."""
    return C.tail(x)


@non_graph_engine
def test_tail():
    """Smoke test: C.tail on a 4-tuple must run; the result is not checked."""
    try_tail((0, 1, 2, 3))


@ms_function
def zero_like_tensor(x):
    """Return C.zeros_like(x)."""
    return C.zeros_like(x)


def test_zeros():
    """zeros_like of a [2, 3] int32 tensor must equal an all-zero tensor of that shape."""
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    res = zero_like_tensor(x)
    assert res == Tensor(np.zeros([2, 3]).astype(np.int32))


def test_ScalarGradChecker():
    """Run check_gradient with ScalarGradChecker on x * y at (1.0, 4.0)."""
    def scalar_f(x, y):
        return x * y

    check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1)


def test_GradCheckerPrimitive():
    """Run check_gradient with OperationGradChecker on a bare P.MatMul call."""
    matmul = P.MatMul()

    def prim_f(x, y):
        return matmul(x, y)

    check_gradient(prim_f, Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker, sampling_times=2)


def test_NNGradChecker():
    """Run check_gradient with NNGradChecker on a single nn.Dense(10, 10) layer."""
    class Net(nn.Cell):
        """Net definition: one dense layer."""

        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out

    check_gradient(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-3,
                   grad_checker_class=NNGradChecker, sampling_times=3)


def test_OperationGradChecker():
    """Run check_gradient with OperationGradChecker and input_selector=[1] on a Cell."""
    class Net(nn.Cell):
        """Net definition: scale x by parameter z, then MatMul with y."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

    check_gradient(Net(), Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker,
                   input_selector=[1], sampling_times=2)


def test_ScalarJacobianChecker():
    """Run check_jacobian with ScalarGradChecker on x * y, selecting input 0."""
    def scalar_f(x, y):
        return x * y

    check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0])


def test_OperationJacobianChecker():
    """Run check_jacobian with OperationGradChecker on a Cell that returns two outputs."""
    class Net(nn.Cell):
        """Net definition: returns both the scaled x and the MatMul result."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return x, out

    check_jacobian(Net(), Tensor(np.array([[0.65, 0.8, 0.8], [0.1, 0.2, 0.3]], np.float32)),
                   Tensor(np.array([[0.1, 0.3], [0.2, 0.2], [-.1, 0.4]], np.float32)),
                   grad_checker_class=OperationGradChecker, input_selector=[0],
                   output_selector=[0])


def test_NNJacobianChecker():
    """Run check_jacobian with NNGradChecker on a dense layer that also returns its input."""
    class Net(nn.Cell):
        """Net definition: returns the dense output together with x."""

        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out, x

    check_jacobian(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-7,
                   grad_checker_class=NNGradChecker,
                   input_selector=[1],
                   output_selector=[0])


def multi_outputs(x, y):
    """Return 2 * (x + y) twice, as a 2-tuple."""
    z = x + y
    return 2 * z, 2 * z


def test_grad_multi_outputs():
    """With sens (1, 1) each input receives 2 + 2 = 4."""
    assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)


@ms_function
def while_sp(x, y, z):
    """Multiply `out` by x once per step while the tensor counter runs from c3 up to c2."""
    out = x
    i = c3
    while i < c2:
        out = mul(x, out)
        i = i + c3
    return out


def test_while_sp():
    """Starting from x = 2, nine multiplications by 2 give 2 ** 10 = 1024."""
    y = Tensor(np.ones([1, 3]).astype(np.float32))
    z = Tensor(np.ones([1, 3]).astype(np.float32))
    x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
    res = while_sp(x, y, z)
    assert res == Tensor(np.ones([1, 3]).astype(np.float32) * 1024.0)


def grad_refactor_simple_1(x, y):
    """Return x * x + 2 * y."""
    return x * x + 2 * y


def test_grad_refactor_simple_1():
    """Gradients at (2, 1) are (2 * x, 2) = (4, 2)."""
    assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)


def grad_refactor_simple_2(x, y, z):
    """Return x * y + z + x * y * z + x + x * y."""
    return x * y + z + x * y * z + x + x * y


def test_grad_refactor_simple_2():
    """Gradients at (2, 3, 0) are (7, 4, 7)."""
    assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)


def grad_refactor_1(a, b):
    """Return a * b computed through a nested function."""
    def inner(x, y):
        return x * y
    return inner(a, b)


def test_grad_refactor_1():
    """Gradients of a * b at (2, 3) are (3, 2)."""
    assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)


def grad_refactor_2(a, b):
    """Return (b * b) * (a * b) using a nested function that closes over b."""
    def inner(x):
        return x * b
    return inner(b) * inner(a)


def test_grad_refactor_2():
    """a * b**3 at (2, 3): gradients are (b**3, 3 * a * b**2) = (27, 54)."""
    assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)


def grad_refactor_3(a):
    """Return 0 when a > 3, otherwise 3 * a."""
    if a > 3:
        return 0
    return 3 * a


def test_grad_refactor_3():
    """At a = 3 the 3 * a branch is taken, so the gradient is 3."""
    assert C.grad_all(grad_refactor_3)(3) == (3,)


def grad_refactor_4(a):
    """Return 3 * a when a > 3, otherwise 0."""
    if a > 3:
        return 3 * a
    return 0


def test_grad_refactor_4():
    """At a = 4 the 3 * a branch is taken, so the gradient is 3."""
    assert C.grad_all(grad_refactor_4)(4) == (3,)


def grad_refactor_5(a):
    """Return 1 when a > 3, otherwise a."""
    if a > 3:
        return 1
    return a


def test_grad_refactor_5():
    """At a = 1 the identity branch is taken, so the gradient is 1."""
    assert C.grad_all(grad_refactor_5)(1) == (1,)


def grad_refactor_6(a, b):
    """Return 3 * a + b when a > b, otherwise 2 * b * a."""
    if a > b:
        return 3 * a + b
    return 2 * b * a


def test_grad_refactor_6():
    """At (3, 2) the 3 * a + b branch is taken, so the gradients are (3, 1)."""
    # The comparison used to be a bare expression whose result was discarded,
    # so this test could never fail; it is now asserted.
    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)


def grad_refactor_while(x):
    """Square x repeatedly while the running value is below 4."""
    rval = x
    while rval < 4:
        rval = rval * rval
    return rval


def test_grad_refactor_9():
    """Starting at 3 the loop squares once, so the gradient is 2 * 3 = 6."""
    assert C.grad_all(grad_refactor_while)(3) == (6,)


def grad_refactor__while_1(x):
    """Return x * x multiplied by 2 and then 3 inside a while loop."""
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def test_grad_refactor_10():
    """The function is 6 * x * x, so the gradient at 5 is 60."""
    assert C.grad_all(grad_refactor__while_1)(5) == (60,)


def test_grad_refactor_11():
    """Smoke test: grad_all on a Cell computing x * y * y; the result is not checked."""
    class Net(nn.Cell):
        """Net definition: no parameters."""

        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            return x * y * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.ones([2]).astype(np.float32)))


def test_grad_refactor_12():
    """Smoke test: grad_all on a Cell that multiplies by a shape-[1] Parameter."""
    class Net(nn.Cell):
        """Net definition: x * z * y with parameter z."""

        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))


def test_grad_refactor_13():
    """Smoke test: grad_by_list over the trainable parameters of a Cell."""
    class Net(nn.Cell):
        """Net definition: x * z * y with a shape-[2] parameter z."""

        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.ones([2]).astype(np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    weights = ParameterTuple(net.trainable_params())
    C.grad_by_list(net, weights)(Tensor(np.ones([2]).astype(np.float32)),
                                 Tensor(np.zeros([2]).astype(np.float32)))


def grad_refactor_14(a, b):
    """Sum three nested functions that close over a and b; one of them branches on its argument."""
    def inner1(x):
        return x * b

    def inner2(x):
        return a * b

    def inner3(x):
        if (x > 2):
            return a
        return b
    return inner1(b) + inner2(a) + inner3(a)


def test_grad_refactor_14():
    """At (2, 3) the value is b * b + a * b + b, so the gradients are (b, 2 * b + a + 1) = (3, 9)."""
    assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)


class IfDeferInline(nn.Cell):
    """Cell that multiplies its input by a constant 0.6 tensor, followed by a no-op `if True` branch."""

    def __init__(self, mul_size):
        super().__init__()
        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
        self.mul = P.Mul()

    def construct(self, inputs):
        x = self.mul(inputs, self.mul_weight)
        if True:
            x = x
        return x


def test_grad_if_defer_inline():
    """With the defer_inline flag set to False, the input gradient must equal the 0.6 weight."""
    network = IfDeferInline([128, 96])
    network.add_flags(defer_inline=False)
    inp = Tensor(np.ones([128, 96]).astype(np.float32))
    grads = C.grad_all(network)(inp)
    assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)


def test_bprop_with_wrong_output_num():
    """A bprop returning one gradient for a two-input primitive must raise TypeError."""
    class BpropWithWrongOutputNum(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')

        def __call__(self, x, y):
            return x

        def infer_shape(self, x_shape, yshape):
            return x_shape

        def infer_dtype(self, x_type, y_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputNum)
    def get_bprop_with_wrong_output_num(self):
        """Generate bprop for BpropWithWrongOutputNum"""
        def bprop(x, y, out, dout):
            # Only one gradient is returned although the primitive takes x and y.
            return (dout,)
        return bprop

    class BpropWithWrongOutputNumCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputNumCell, self).__init__()

        def construct(self, x, y):
            return BpropWithWrongOutputNum()(x, y)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)


def test_bprop_with_wrong_output_type():
    """A bprop returning a Python int for a tensor input must raise TypeError."""
    class BpropWithWrongOutputType(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputType)
    def get_bprop_with_wrong_output_type(self):
        """Generate bprop for BpropWithWrongOutputType"""
        def bprop(x, out, dout):
            # The scalar 1 is returned where the test feeds a tensor input.
            return (1,)
        return bprop

    class BpropWithWrongOutputTypeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputTypeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputType()(x)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))


def test_bprop_with_wrong_output_shape():
    """A bprop returning a shape-[2] tensor for a [64, 10] input must raise TypeError."""
    class BpropWithWrongOutputShape(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputShape)
    def get_bprop_with_wrong_output_shape(self):
        """Generate bprop for BpropWithWrongOutputShape"""
        # Shape [2] does not match the [64, 10] input used by the test.
        ones = Tensor(np.ones([2,]).astype(np.int32))

        def bprop(x, out, dout):
            return (ones,)
        return bprop

    class BpropWithWrongOutputShapeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputShapeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputShape()(x)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))