From 8eabb86c4b4a31704f11b565096a702f24fbe89d Mon Sep 17 00:00:00 2001 From: buxue Date: Thu, 10 Dec 2020 11:04:29 +0800 Subject: [PATCH] optimize error info when assign parameter to non-parameter attr --- mindspore/common/tensor.py | 8 ++++---- mindspore/nn/cell.py | 28 ++++++++++++++++------------ 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 85ec361caf..e14268a61a 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -17,11 +17,11 @@ import numpy as np from mindspore import log as logger from mindspore.communication.management import get_rank, get_group_size -from .._c_expression import Tensor as Tensor_ -from .._c_expression import MetaTensor as MetaTensor_ -from .._checkparam import Validator as validator from . import dtype as mstype from ._register_for_tensor import tensor_operator_registry +from .._c_expression import MetaTensor as MetaTensor_ +from .._c_expression import Tensor as Tensor_ +from .._checkparam import Validator as validator __all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor'] np_types = (np.int8, np.int16, np.int32, np.int64, @@ -177,7 +177,7 @@ class Tensor(Tensor_): return out def __getitem__(self, index): - if isinstance(index, int) and index >= self.shape[0]: + if isinstance(index, int) and self.shape and index >= self.shape[0]: raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0])) out = tensor_operator_registry.get('__getitem__')(self, index) return out diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 87a0d53169..a126766a90 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -13,25 +13,27 @@ # limitations under the License. # ============================================================================ """cell""" -import inspect -import time import gc +import inspect import os +import time from collections import OrderedDict + import numpy + from mindspore import log as logger from mindspore.common.parameter import PARAMETER_NAME_DEFAULT from .. import context +from .._c_expression import init_backend, Cell_ +from .._checkparam import Validator from ..common import dtype as mstype from ..common.api import _executor, _pynative_exec -from .._checkparam import Validator from ..common.parameter import Parameter, ParameterTuple -from .._c_expression import init_backend, Cell_ -from ..ops.primitive import Primitive -from ..ops.operations import HookBackward +from ..common.tensor import Tensor, MetaTensor from ..ops.functional import cast +from ..ops.operations import HookBackward +from ..ops.primitive import Primitive from ..parallel._tensor import _load_tensor_by_layout -from ..common.tensor import Tensor, MetaTensor class Cell(Cell_): @@ -393,10 +395,10 @@ class Cell(Cell_): raise AttributeError("Can not assign params before Cell.__init__() call.") if name in self.__dict__: if self.__dict__[name] is not None: - raise TypeError("Expected type is not in (Parameter, Cell), but got Parameter.") + raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.") del self.__dict__[name] if cells and name in cells: - raise TypeError("Expected type is Cell, but got Parameter.") + raise TypeError("The type of value should be Cell, but got Parameter.") self.insert_param_to_cell(name, value) elif isinstance(value, ParameterTuple): if params is None: @@ -417,7 +419,7 @@ class Cell(Cell_): if name in self.__dict__: del self.__dict__[name] if params and name in params: - raise TypeError("Expected type is Parameter, but got Cell.") + raise TypeError("The type of value should be Parameter, but got Cell.") if self._auto_prefix: value.update_parameters_name(name + '.') cells[name] = value @@ -427,12 +429,13 @@ class Cell(Cell_): if isinstance(value, Tensor) and self._params[name] is not None: self._params[name].set_data(value) elif value is not None: - raise TypeError("Expected type in (Parameter, ParameterTuple), but got {}.".format(type(value))) + raise TypeError(f"The type of value should be Parameter or ParameterTuple, " + f"but got {type(value).__name__}.") else: self.insert_param_to_cell(name, None) elif cells and name in cells: if value is not None: - raise TypeError("Expected type is cell, but got {}.".format(type(value))) + raise TypeError(f"The type of value should be cell, but got {type(value).__name__}.") self._cells[name] = None elif isinstance(value, Tensor): if context.get_context("mode") == context.PYNATIVE_MODE: @@ -705,6 +708,7 @@ class Cell(Cell_): new_p = param.init_data(layout, set_sliced=set_sliced) replace[param] = new_p return new_p + # replace all original usage. cells = self.cells_and_names() for _, cell in cells: