From 483c8a179aa188f7e5322bbfca53014cdd53e6fe Mon Sep 17 00:00:00 2001 From: buxue Date: Sat, 19 Sep 2020 11:43:30 +0800 Subject: [PATCH] improve the recognition of Parameter object and raise error when convert keywordarg to pydata --- mindspore/_extends/parse/namespace.py | 4 ++-- mindspore/ccsrc/pipeline/jit/parse/resolve.cc | 4 ++-- mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc | 3 ++- mindspore/ccsrc/pipeline/pynative/pynative_execute.cc | 4 ++-- mindspore/ccsrc/utils/convert_utils_py.cc | 2 +- mindspore/common/parameter.py | 2 ++ mindspore/ops/primitive.py | 8 +++++++- .../train/summary/test_summary_ops_params_valid_check.py | 1 + tests/ut/python/train/summary/test_tensor_summary.py | 1 - 9 files changed, 19 insertions(+), 10 deletions(-) diff --git a/mindspore/_extends/parse/namespace.py b/mindspore/_extends/parse/namespace.py index ef2c2b1fbc..8421cec212 100644 --- a/mindspore/_extends/parse/namespace.py +++ b/mindspore/_extends/parse/namespace.py @@ -118,5 +118,5 @@ class ClassMemberNamespace(Namespace): except ValueError: raise UnboundLocalError(name) except KeyError: - logger.warning(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', " - f"so will return None.") + logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.") + raise AttributeError(name) diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 60ec92167d..ce72c19cf9 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -89,7 +89,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; } - std::string param_name = py::cast(name_attr); + auto param_name = py::cast(name_attr); auto top_graph = Parser::GetTopFuncGraph(); // if the parameter node has been created , return it AnfNodePtr para_node = nullptr; @@ -115,7 +115,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { AnfNodePtr output = nullptr; - if (py::hasattr(obj, "__parameter__")) { + if (py::hasattr(obj, "__parameter__") && py::isinstance(obj)) { auto param = ResolveParameterObj(func_graph, obj); if (param == nullptr) { MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr"; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 4d05f7d064..855104dccd 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -1014,7 +1014,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { // create class instance auto obj = parse::data_converter::CreatePythonObject(class_type, params); if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type"; + MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type) + << " failed, only support create Cell or Primitive object."; } // process the object diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 93dd17c8db..60df3a5220 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1361,8 +1361,8 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje auto params = newfg->parameters(); auto manager = Manage({newfg}, false); if (args.size() > params.size()) { - MS_EXCEPTION(ValueError) << "The number of arguments " << args.size() - << " is more than the number of parameters required, which is " << params.size(); + MS_EXCEPTION(TypeError) << "The number of arguments " << args.size() + << " is more than the number of parameters required, which is " << params.size(); } for (size_t i = 0; i < args.size(); i++) { ValuePtr value = PyAttrValue(args[i]); diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index d09d4485e6..d85593a22a 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -147,7 +147,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) { } else if (value->isa()) { ret = py::none(); } else { - MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData."; + MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData."; } return ret; } diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index cd4ab384a1..1213f6acb8 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -48,6 +48,8 @@ class Parameter(MetaTensor): Note: Each parameter of Cell is represented by Parameter class. A Parameter has to belong to a Cell. + If there is an operator in the network that requires part of the inputs to be Parameter, + then the Parameters as this part of the inputs are not allowed to be cast. Args: default_input (Union[Tensor, Initializer, Number]): Parameter data, to be set initialized. diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 624848cf21..1b470b36fd 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -22,6 +22,7 @@ from mindspore import context from .._c_expression import Primitive_, real_run_op, prim_type from . import signature as sig + class Primitive(Primitive_): """ Primitive is the base class of primitives in python. @@ -168,7 +169,7 @@ class Primitive(Primitive_): return type(self)(**self.init_attrs) def __repr__(self): - attr = ', '.join([f'{k}={self.attrs[k]}'for k in self.attrs if not k in Primitive._repr_ignore_list]) + attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list]) info_str = f'Prim[{self.name}]' if attr: info_str += f'<{attr}>' @@ -425,6 +426,7 @@ def prim_attr_register(fn): Returns: function, original function. """ + def deco(self, *args, **kwargs): if isinstance(self, PrimitiveWithInfer): PrimitiveWithInfer.__init__(self, self.__class__.__name__) @@ -442,6 +444,7 @@ def prim_attr_register(fn): self.add_prim_attr(name, value) self.init_attrs[name] = value fn(self, *args, **kwargs) + deco.decorated_func = fn return deco @@ -470,6 +473,7 @@ def constexpr(fn=None, get_instance=True, name=None): >>> return len(x) >>> assert tuple_len_class()(a) == 2 """ + def deco(fn): class CompileOp(PrimitiveWithInfer): def __init__(self): @@ -479,9 +483,11 @@ def constexpr(fn=None, get_instance=True, name=None): def infer_value(self, *args): return fn(*args) + if get_instance: return CompileOp() return CompileOp + if fn is not None: return deco(fn) return deco diff --git a/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py b/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py index 4b5180b963..6a7577bdf1 100644 --- a/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py +++ b/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py @@ -143,6 +143,7 @@ class TestSummaryOps: (SummaryEnum.TENSOR.value, Tensor(0)), (SummaryEnum.HISTOGRAM.value, Tensor(0)) ]) + def test_value_shape_invalid(self, summary_type, value): """Test invalid shape of every summary operators.""" net = SummaryNet(summary_type, tag='tag', data=value) diff --git a/tests/ut/python/train/summary/test_tensor_summary.py b/tests/ut/python/train/summary/test_tensor_summary.py index 9a028a8a12..2bbd7fa56b 100644 --- a/tests/ut/python/train/summary/test_tensor_summary.py +++ b/tests/ut/python/train/summary/test_tensor_summary.py @@ -122,7 +122,6 @@ class SummaryDemo(nn.Cell): self.s("y1", y) return z - def test_tensor_summary_with_ge(): """ test_tensor_summary_with_ge """ log.debug("begin test_tensor_summary_with_ge")