optimize error info when assign parameter to non-parameter attr

pull/9761/head
buxue 5 years ago
parent 27a90a6a1b
commit 8eabb86c4b

@ -17,11 +17,11 @@ import numpy as np
from mindspore import log as logger from mindspore import log as logger
from mindspore.communication.management import get_rank, get_group_size from mindspore.communication.management import get_rank, get_group_size
from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor as MetaTensor_
from .._checkparam import Validator as validator
from . import dtype as mstype from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry from ._register_for_tensor import tensor_operator_registry
from .._c_expression import MetaTensor as MetaTensor_
from .._c_expression import Tensor as Tensor_
from .._checkparam import Validator as validator
__all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor'] __all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor']
np_types = (np.int8, np.int16, np.int32, np.int64, np_types = (np.int8, np.int16, np.int32, np.int64,
@ -177,7 +177,7 @@ class Tensor(Tensor_):
return out return out
def __getitem__(self, index): def __getitem__(self, index):
if isinstance(index, int) and index >= self.shape[0]: if isinstance(index, int) and self.shape and index >= self.shape[0]:
raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0])) raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0]))
out = tensor_operator_registry.get('__getitem__')(self, index) out = tensor_operator_registry.get('__getitem__')(self, index)
return out return out

@ -13,25 +13,27 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""cell""" """cell"""
import inspect
import time
import gc import gc
import inspect
import os import os
import time
from collections import OrderedDict from collections import OrderedDict
import numpy import numpy
from mindspore import log as logger from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from .. import context from .. import context
from .._c_expression import init_backend, Cell_
from .._checkparam import Validator
from ..common import dtype as mstype from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec from ..common.api import _executor, _pynative_exec
from .._checkparam import Validator
from ..common.parameter import Parameter, ParameterTuple from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend, Cell_ from ..common.tensor import Tensor, MetaTensor
from ..ops.primitive import Primitive
from ..ops.operations import HookBackward
from ..ops.functional import cast from ..ops.functional import cast
from ..ops.operations import HookBackward
from ..ops.primitive import Primitive
from ..parallel._tensor import _load_tensor_by_layout from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor, MetaTensor
class Cell(Cell_): class Cell(Cell_):
@ -393,10 +395,10 @@ class Cell(Cell_):
raise AttributeError("Can not assign params before Cell.__init__() call.") raise AttributeError("Can not assign params before Cell.__init__() call.")
if name in self.__dict__: if name in self.__dict__:
if self.__dict__[name] is not None: if self.__dict__[name] is not None:
raise TypeError("Expected type is not in (Parameter, Cell), but got Parameter.") raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.")
del self.__dict__[name] del self.__dict__[name]
if cells and name in cells: if cells and name in cells:
raise TypeError("Expected type is Cell, but got Parameter.") raise TypeError("The type of value should be Cell, but got Parameter.")
self.insert_param_to_cell(name, value) self.insert_param_to_cell(name, value)
elif isinstance(value, ParameterTuple): elif isinstance(value, ParameterTuple):
if params is None: if params is None:
@ -417,7 +419,7 @@ class Cell(Cell_):
if name in self.__dict__: if name in self.__dict__:
del self.__dict__[name] del self.__dict__[name]
if params and name in params: if params and name in params:
raise TypeError("Expected type is Parameter, but got Cell.") raise TypeError("The type of value should be Parameter, but got Cell.")
if self._auto_prefix: if self._auto_prefix:
value.update_parameters_name(name + '.') value.update_parameters_name(name + '.')
cells[name] = value cells[name] = value
@ -427,12 +429,13 @@ class Cell(Cell_):
if isinstance(value, Tensor) and self._params[name] is not None: if isinstance(value, Tensor) and self._params[name] is not None:
self._params[name].set_data(value) self._params[name].set_data(value)
elif value is not None: elif value is not None:
raise TypeError("Expected type in (Parameter, ParameterTuple), but got {}.".format(type(value))) raise TypeError(f"The type of value should be Parameter or ParameterTuple, "
f"but got {type(value).__name__}.")
else: else:
self.insert_param_to_cell(name, None) self.insert_param_to_cell(name, None)
elif cells and name in cells: elif cells and name in cells:
if value is not None: if value is not None:
raise TypeError("Expected type is cell, but got {}.".format(type(value))) raise TypeError(f"The type of value should be cell, but got {type(value).__name__}.")
self._cells[name] = None self._cells[name] = None
elif isinstance(value, Tensor): elif isinstance(value, Tensor):
if context.get_context("mode") == context.PYNATIVE_MODE: if context.get_context("mode") == context.PYNATIVE_MODE:
@ -705,6 +708,7 @@ class Cell(Cell_):
new_p = param.init_data(layout, set_sliced=set_sliced) new_p = param.init_data(layout, set_sliced=set_sliced)
replace[param] = new_p replace[param] = new_p
return new_p return new_p
# replace all original usage. # replace all original usage.
cells = self.cells_and_names() cells = self.cells_and_names()
for _, cell in cells: for _, cell in cells:

Loading…
Cancel
Save