remove the 'raise' in construct of Cell

pull/10035/head
buxue 4 years ago
parent aee1ff43ef
commit aeeef7607a

@ -315,6 +315,9 @@ class Cell(Cell_):
return tuple(res) return tuple(res)
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
if self.__class__.construct is Cell.construct:
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
f"will call the super class(Cell) 'construct'.")
if kwargs: if kwargs:
bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs) bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
inputs = bound_args.args inputs = bound_args.args
@ -681,7 +684,7 @@ class Cell(Cell_):
Returns: Returns:
Tensor, returns the computed result. Tensor, returns the computed result.
""" """
raise NotImplementedError return None
def init_parameters_data(self, auto_parallel_mode=False): def init_parameters_data(self, auto_parallel_mode=False):
""" """

@ -197,8 +197,7 @@ def test_exceptions():
ModError2(t) ModError2(t)
m = nn.Cell() m = nn.Cell()
with pytest.raises(NotImplementedError): assert m.construct() is None
m.construct()
def test_cell_copy(): def test_cell_copy():

@ -63,9 +63,7 @@ def test_net_without_construct():
""" test_net_without_construct """ """ test_net_without_construct """
net = NetMissConstruct() net = NetMissConstruct()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
with pytest.raises(RuntimeError) as err:
_executor.compile(net, inp) _executor.compile(net, inp)
assert "Unsupported syntax 'Raise' at " in str(err.value)
class NetWithRaise(nn.Cell): class NetWithRaise(nn.Cell):

@ -196,6 +196,4 @@ def test_missing_construct():
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_) np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input) tensor = Tensor(np_input)
net = NetMissConstruct() net = NetMissConstruct()
with pytest.raises(RuntimeError) as er: assert net(tensor) is None
net(tensor)
assert "Unsupported syntax 'Raise' at " in str(er.value)

@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
""" test super""" """ test super"""
import numpy as np import numpy as np
import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
@ -108,9 +107,7 @@ def test_super_cell():
net = Net(2) net = Net(2)
x = Tensor(np.ones([1, 2, 3], np.int32)) x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32)) y = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as er: assert net(x, y) is None
net(x, y)
assert "Unsupported syntax 'Raise'" in str(er.value)
def test_single_super_in(): def test_single_super_in():

@ -212,8 +212,7 @@ def test_exceptions():
ModError2(t) ModError2(t)
m = nn.Cell() m = nn.Cell()
with pytest.raises(NotImplementedError): assert m.construct() is None
m.construct()
def test_del(): def test_del():

Loading…
Cancel
Save