|
|
|
@ -19,21 +19,27 @@
|
|
|
|
|
@Desc :
|
|
|
|
|
"""
|
|
|
|
|
import logging
|
|
|
|
|
import pytest
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
import mindspore as ms
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.common.api import ms_function, _executor
|
|
|
|
|
from mindspore.ops._grad.grad_base import bprop_getters
|
|
|
|
|
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
|
|
|
|
from mindspore.ops.functional import tensor_add
|
|
|
|
|
from ...ut_filter import non_graph_engine
|
|
|
|
|
|
|
|
|
|
# pylint: disable=W0613
|
|
|
|
|
# pylint: disable=W0613,W0612
|
|
|
|
|
# W0613: unused-argument
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log = logging.getLogger("test")
|
|
|
|
|
log.setLevel(level=logging.ERROR)
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test case: use the parse obj interface use default parameter
|
|
|
|
@ -135,3 +141,113 @@ def test_net_with_ndarray():
|
|
|
|
|
input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
|
|
|
|
|
|
|
|
|
|
net(ms.Tensor(input_data))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_bprop_with_wrong_output_num():
|
|
|
|
|
context.set_context(check_bprop=True)
|
|
|
|
|
class BpropWithWrongOutputNum(PrimitiveWithInfer):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
|
|
|
|
|
|
|
|
|
|
def __call__(self, x, y):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, yshape):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_type, y_type):
|
|
|
|
|
return x_type
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(BpropWithWrongOutputNum)
|
|
|
|
|
def get_bprop_with_wrong_output_num(self):
|
|
|
|
|
"""Generate bprop for BpropWithWrongOutputNum"""
|
|
|
|
|
|
|
|
|
|
def bprop(x, y, out, dout):
|
|
|
|
|
return (dout,)
|
|
|
|
|
|
|
|
|
|
return bprop
|
|
|
|
|
|
|
|
|
|
class BpropWithWrongOutputNumCell(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(BpropWithWrongOutputNumCell, self).__init__()
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y):
|
|
|
|
|
return BpropWithWrongOutputNum()(x, y)
|
|
|
|
|
|
|
|
|
|
with pytest.raises(TypeError):
|
|
|
|
|
C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
|
|
|
|
|
|
|
|
|
|
def test_bprop_with_wrong_output_type():
|
|
|
|
|
context.set_context(check_bprop=True)
|
|
|
|
|
class BpropWithWrongOutputType(PrimitiveWithInfer):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
|
|
|
|
|
|
|
|
|
|
def __call__(self, x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_type):
|
|
|
|
|
return x_type
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(BpropWithWrongOutputType)
|
|
|
|
|
def get_bprop_with_wrong_output_type(self):
|
|
|
|
|
"""Generate bprop for BpropWithWrongOutputType"""
|
|
|
|
|
|
|
|
|
|
def bprop(x, out, dout):
|
|
|
|
|
return (1,)
|
|
|
|
|
|
|
|
|
|
return bprop
|
|
|
|
|
|
|
|
|
|
class BpropWithWrongOutputTypeCell(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(BpropWithWrongOutputTypeCell, self).__init__()
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
return BpropWithWrongOutputType()(x)
|
|
|
|
|
|
|
|
|
|
with pytest.raises(TypeError):
|
|
|
|
|
C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_bprop_with_wrong_output_shape():
|
|
|
|
|
context.set_context(check_bprop=True)
|
|
|
|
|
class BpropWithWrongOutputShape(PrimitiveWithInfer):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
|
|
|
|
|
|
|
|
|
|
def __call__(self, x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_type):
|
|
|
|
|
return x_type
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(BpropWithWrongOutputShape)
|
|
|
|
|
def get_bprop_with_wrong_output_shape(self):
|
|
|
|
|
"""Generate bprop for BpropWithWrongOutputShape"""
|
|
|
|
|
ones = Tensor(np.ones([2,]).astype(np.int32))
|
|
|
|
|
|
|
|
|
|
def bprop(x, out, dout):
|
|
|
|
|
return (ones,)
|
|
|
|
|
|
|
|
|
|
return bprop
|
|
|
|
|
|
|
|
|
|
class BpropWithWrongOutputShapeCell(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(BpropWithWrongOutputShapeCell, self).__init__()
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
return BpropWithWrongOutputShape()(x)
|
|
|
|
|
|
|
|
|
|
with pytest.raises(TypeError):
|
|
|
|
|
net = BpropWithWrongOutputShapeCell()
|
|
|
|
|
net.set_grad()
|
|
|
|
|
C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
|
|
|
|
|