!6545 improve the recognition of Parameter object and raise error when convert keywordarg to pydata

Merge pull request !6545 from zhangbuxue/improve_the_recognition_of_Parameter_object_and_raise_error_when_convert_keywordarg_to_pydata
pull/6545/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 4905de06bd

@ -118,5 +118,5 @@ class ClassMemberNamespace(Namespace):
except ValueError: except ValueError:
raise UnboundLocalError(name) raise UnboundLocalError(name)
except KeyError: except KeyError:
logger.warning(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', " logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
f"so will return None.") raise AttributeError(name)

@ -89,7 +89,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
} }
std::string param_name = py::cast<std::string>(name_attr); auto param_name = py::cast<std::string>(name_attr);
auto top_graph = Parser::GetTopFuncGraph(); auto top_graph = Parser::GetTopFuncGraph();
// if the parameter node has been created , return it // if the parameter node has been created , return it
AnfNodePtr para_node = nullptr; AnfNodePtr para_node = nullptr;
@ -115,7 +115,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
AnfNodePtr output = nullptr; AnfNodePtr output = nullptr;
if (py::hasattr(obj, "__parameter__")) { if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
auto param = ResolveParameterObj(func_graph, obj); auto param = ResolveParameterObj(func_graph, obj);
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr"; MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";

@ -1014,7 +1014,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
// create class instance // create class instance
auto obj = parse::data_converter::CreatePythonObject(class_type, params); auto obj = parse::data_converter::CreatePythonObject(class_type, params);
if (py::isinstance<py::none>(obj)) { if (py::isinstance<py::none>(obj)) {
MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type"; MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type)
<< " failed, only support create Cell or Primitive object.";
} }
// process the object // process the object

@ -1355,8 +1355,8 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
auto params = newfg->parameters(); auto params = newfg->parameters();
auto manager = Manage({newfg}, false); auto manager = Manage({newfg}, false);
if (args.size() > params.size()) { if (args.size() > params.size()) {
MS_EXCEPTION(ValueError) << "The number of arguments " << args.size() MS_EXCEPTION(TypeError) << "The number of arguments " << args.size()
<< " is more than the number of parameters required, which is " << params.size(); << " is more than the number of parameters required, which is " << params.size();
} }
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
ValuePtr value = PyAttrValue(args[i]); ValuePtr value = PyAttrValue(args[i]);

@ -147,7 +147,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
} else if (value->isa<None>()) { } else if (value->isa<None>()) {
ret = py::none(); ret = py::none();
} else { } else {
MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData."; MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
} }
return ret; return ret;
} }

@ -48,6 +48,8 @@ class Parameter(MetaTensor):
Note: Note:
Each parameter of Cell is represented by Parameter class. Each parameter of Cell is represented by Parameter class.
A Parameter has to belong to a Cell. A Parameter has to belong to a Cell.
If there is an operator in the network that requires part of the inputs to be Parameter,
then the Parameters as this part of the inputs are not allowed to be cast.
Args: Args:
default_input (Union[Tensor, Initializer, Number]): Parameter data, to be set initialized. default_input (Union[Tensor, Initializer, Number]): Parameter data, to be set initialized.

@ -22,6 +22,7 @@ from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type from .._c_expression import Primitive_, real_run_op, prim_type
from . import signature as sig from . import signature as sig
class Primitive(Primitive_): class Primitive(Primitive_):
""" """
Primitive is the base class of primitives in python. Primitive is the base class of primitives in python.
@ -168,7 +169,7 @@ class Primitive(Primitive_):
return type(self)(**self.init_attrs) return type(self)(**self.init_attrs)
def __repr__(self): def __repr__(self):
attr = ', '.join([f'{k}={self.attrs[k]}'for k in self.attrs if not k in Primitive._repr_ignore_list]) attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list])
info_str = f'Prim[{self.name}]' info_str = f'Prim[{self.name}]'
if attr: if attr:
info_str += f'<{attr}>' info_str += f'<{attr}>'
@ -425,6 +426,7 @@ def prim_attr_register(fn):
Returns: Returns:
function, original function. function, original function.
""" """
def deco(self, *args, **kwargs): def deco(self, *args, **kwargs):
if isinstance(self, PrimitiveWithInfer): if isinstance(self, PrimitiveWithInfer):
PrimitiveWithInfer.__init__(self, self.__class__.__name__) PrimitiveWithInfer.__init__(self, self.__class__.__name__)
@ -442,6 +444,7 @@ def prim_attr_register(fn):
self.add_prim_attr(name, value) self.add_prim_attr(name, value)
self.init_attrs[name] = value self.init_attrs[name] = value
fn(self, *args, **kwargs) fn(self, *args, **kwargs)
deco.decorated_func = fn deco.decorated_func = fn
return deco return deco
@ -470,6 +473,7 @@ def constexpr(fn=None, get_instance=True, name=None):
>>> return len(x) >>> return len(x)
>>> assert tuple_len_class()(a) == 2 >>> assert tuple_len_class()(a) == 2
""" """
def deco(fn): def deco(fn):
class CompileOp(PrimitiveWithInfer): class CompileOp(PrimitiveWithInfer):
def __init__(self): def __init__(self):
@ -479,9 +483,11 @@ def constexpr(fn=None, get_instance=True, name=None):
def infer_value(self, *args): def infer_value(self, *args):
return fn(*args) return fn(*args)
if get_instance: if get_instance:
return CompileOp() return CompileOp()
return CompileOp return CompileOp
if fn is not None: if fn is not None:
return deco(fn) return deco(fn)
return deco return deco

@ -143,6 +143,7 @@ class TestSummaryOps:
(SummaryEnum.TENSOR.value, Tensor(0)), (SummaryEnum.TENSOR.value, Tensor(0)),
(SummaryEnum.HISTOGRAM.value, Tensor(0)) (SummaryEnum.HISTOGRAM.value, Tensor(0))
]) ])
def test_value_shape_invalid(self, summary_type, value): def test_value_shape_invalid(self, summary_type, value):
"""Test invalid shape of every summary operators.""" """Test invalid shape of every summary operators."""
net = SummaryNet(summary_type, tag='tag', data=value) net = SummaryNet(summary_type, tag='tag', data=value)

@ -122,7 +122,6 @@ class SummaryDemo(nn.Cell):
self.s("y1", y) self.s("y1", y)
return z return z
def test_tensor_summary_with_ge(): def test_tensor_summary_with_ge():
""" test_tensor_summary_with_ge """ """ test_tensor_summary_with_ge """
log.debug("begin test_tensor_summary_with_ge") log.debug("begin test_tensor_summary_with_ge")

Loading…
Cancel
Save