remove the 'raise' in construct of Cell

pull/10035/head
buxue 4 years ago
parent aee1ff43ef
commit aeeef7607a

@ -315,6 +315,9 @@ class Cell(Cell_):
return tuple(res)
def __call__(self, *inputs, **kwargs):
if self.__class__.construct is Cell.construct:
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
f"will call the super class(Cell) 'construct'.")
if kwargs:
bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
inputs = bound_args.args
@ -681,7 +684,7 @@ class Cell(Cell_):
Returns:
Tensor, returns the computed result.
"""
raise NotImplementedError
return None
def init_parameters_data(self, auto_parallel_mode=False):
"""

@ -197,8 +197,7 @@ def test_exceptions():
ModError2(t)
m = nn.Cell()
with pytest.raises(NotImplementedError):
m.construct()
assert m.construct() is None
def test_cell_copy():

@ -63,9 +63,7 @@ def test_net_without_construct():
""" test_net_without_construct """
net = NetMissConstruct()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
with pytest.raises(RuntimeError) as err:
_executor.compile(net, inp)
assert "Unsupported syntax 'Raise' at " in str(err.value)
class NetWithRaise(nn.Cell):

@ -196,6 +196,4 @@ def test_missing_construct():
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input)
net = NetMissConstruct()
with pytest.raises(RuntimeError) as er:
net(tensor)
assert "Unsupported syntax 'Raise' at " in str(er.value)
assert net(tensor) is None

@ -14,7 +14,6 @@
# ============================================================================
""" test super"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
@ -108,9 +107,7 @@ def test_super_cell():
net = Net(2)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as er:
net(x, y)
assert "Unsupported syntax 'Raise'" in str(er.value)
assert net(x, y) is None
def test_single_super_in():

@ -212,8 +212,7 @@ def test_exceptions():
ModError2(t)
m = nn.Cell()
with pytest.raises(NotImplementedError):
m.construct()
assert m.construct() is None
def test_del():

Loading…
Cancel
Save