optimize error info when assign parameter to non-parameter attr

pull/9761/head
buxue 4 years ago
parent 27a90a6a1b
commit 8eabb86c4b

@ -17,11 +17,11 @@ import numpy as np
from mindspore import log as logger
from mindspore.communication.management import get_rank, get_group_size
from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor as MetaTensor_
from .._checkparam import Validator as validator
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry
from .._c_expression import MetaTensor as MetaTensor_
from .._c_expression import Tensor as Tensor_
from .._checkparam import Validator as validator
__all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
@ -177,7 +177,7 @@ class Tensor(Tensor_):
return out
def __getitem__(self, index):
if isinstance(index, int) and index >= self.shape[0]:
if isinstance(index, int) and self.shape and index >= self.shape[0]:
raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0]))
out = tensor_operator_registry.get('__getitem__')(self, index)
return out

@ -13,25 +13,27 @@
# limitations under the License.
# ============================================================================
"""cell"""
import inspect
import time
import gc
import inspect
import os
import time
from collections import OrderedDict
import numpy
from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from .. import context
from .._c_expression import init_backend, Cell_
from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec
from .._checkparam import Validator
from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend, Cell_
from ..ops.primitive import Primitive
from ..ops.operations import HookBackward
from ..common.tensor import Tensor, MetaTensor
from ..ops.functional import cast
from ..ops.operations import HookBackward
from ..ops.primitive import Primitive
from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor, MetaTensor
class Cell(Cell_):
@ -393,10 +395,10 @@ class Cell(Cell_):
raise AttributeError("Can not assign params before Cell.__init__() call.")
if name in self.__dict__:
if self.__dict__[name] is not None:
raise TypeError("Expected type is not in (Parameter, Cell), but got Parameter.")
raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.")
del self.__dict__[name]
if cells and name in cells:
raise TypeError("Expected type is Cell, but got Parameter.")
raise TypeError("The type of value should be Cell, but got Parameter.")
self.insert_param_to_cell(name, value)
elif isinstance(value, ParameterTuple):
if params is None:
@ -417,7 +419,7 @@ class Cell(Cell_):
if name in self.__dict__:
del self.__dict__[name]
if params and name in params:
raise TypeError("Expected type is Parameter, but got Cell.")
raise TypeError("The type of value should be Parameter, but got Cell.")
if self._auto_prefix:
value.update_parameters_name(name + '.')
cells[name] = value
@ -427,12 +429,13 @@ class Cell(Cell_):
if isinstance(value, Tensor) and self._params[name] is not None:
self._params[name].set_data(value)
elif value is not None:
raise TypeError("Expected type in (Parameter, ParameterTuple), but got {}.".format(type(value)))
raise TypeError(f"The type of value should be Parameter or ParameterTuple, "
f"but got {type(value).__name__}.")
else:
self.insert_param_to_cell(name, None)
elif cells and name in cells:
if value is not None:
raise TypeError("Expected type is cell, but got {}.".format(type(value)))
raise TypeError(f"The type of value should be cell, but got {type(value).__name__}.")
self._cells[name] = None
elif isinstance(value, Tensor):
if context.get_context("mode") == context.PYNATIVE_MODE:
@ -705,6 +708,7 @@ class Cell(Cell_):
new_p = param.init_data(layout, set_sliced=set_sliced)
replace[param] = new_p
return new_p
# replace all original usage.
cells = self.cells_and_names()
for _, cell in cells:

Loading…
Cancel
Save