improve the recognition of Parameter object and raise error when convert keywordarg to pydata

pull/6545/head
buxue 4 years ago
parent 473b9614a7
commit 483c8a179a

@ -118,5 +118,5 @@ class ClassMemberNamespace(Namespace):
except ValueError:
raise UnboundLocalError(name)
except KeyError:
logger.warning(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', "
f"so will return None.")
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name)

@ -89,7 +89,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
}
std::string param_name = py::cast<std::string>(name_attr);
auto param_name = py::cast<std::string>(name_attr);
auto top_graph = Parser::GetTopFuncGraph();
// if the parameter node has been created , return it
AnfNodePtr para_node = nullptr;
@ -115,7 +115,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
AnfNodePtr output = nullptr;
if (py::hasattr(obj, "__parameter__")) {
if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
auto param = ResolveParameterObj(func_graph, obj);
if (param == nullptr) {
MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";

@ -1014,7 +1014,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
// create class instance
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
if (py::isinstance<py::none>(obj)) {
MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type";
MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type)
<< " failed, only support create Cell or Primitive object.";
}
// process the object

@ -1361,8 +1361,8 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
if (args.size() > params.size()) {
MS_EXCEPTION(ValueError) << "The number of arguments " << args.size()
<< " is more than the number of parameters required, which is " << params.size();
MS_EXCEPTION(TypeError) << "The number of arguments " << args.size()
<< " is more than the number of parameters required, which is " << params.size();
}
for (size_t i = 0; i < args.size(); i++) {
ValuePtr value = PyAttrValue(args[i]);

@ -147,7 +147,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
} else if (value->isa<None>()) {
ret = py::none();
} else {
MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
}
return ret;
}

@ -48,6 +48,8 @@ class Parameter(MetaTensor):
Note:
Each parameter of Cell is represented by Parameter class.
A Parameter has to belong to a Cell.
If there is an operator in the network that requires part of the inputs to be Parameter,
then the Parameters as this part of the inputs are not allowed to be cast.
Args:
default_input (Union[Tensor, Initializer, Number]): Parameter data, to be set initialized.

@ -22,6 +22,7 @@ from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type
from . import signature as sig
class Primitive(Primitive_):
"""
Primitive is the base class of primitives in python.
@ -168,7 +169,7 @@ class Primitive(Primitive_):
return type(self)(**self.init_attrs)
def __repr__(self):
attr = ', '.join([f'{k}={self.attrs[k]}'for k in self.attrs if not k in Primitive._repr_ignore_list])
attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list])
info_str = f'Prim[{self.name}]'
if attr:
info_str += f'<{attr}>'
@ -425,6 +426,7 @@ def prim_attr_register(fn):
Returns:
function, original function.
"""
def deco(self, *args, **kwargs):
if isinstance(self, PrimitiveWithInfer):
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
@ -442,6 +444,7 @@ def prim_attr_register(fn):
self.add_prim_attr(name, value)
self.init_attrs[name] = value
fn(self, *args, **kwargs)
deco.decorated_func = fn
return deco
@ -470,6 +473,7 @@ def constexpr(fn=None, get_instance=True, name=None):
>>> return len(x)
>>> assert tuple_len_class()(a) == 2
"""
def deco(fn):
class CompileOp(PrimitiveWithInfer):
def __init__(self):
@ -479,9 +483,11 @@ def constexpr(fn=None, get_instance=True, name=None):
def infer_value(self, *args):
return fn(*args)
if get_instance:
return CompileOp()
return CompileOp
if fn is not None:
return deco(fn)
return deco

@ -143,6 +143,7 @@ class TestSummaryOps:
(SummaryEnum.TENSOR.value, Tensor(0)),
(SummaryEnum.HISTOGRAM.value, Tensor(0))
])
def test_value_shape_invalid(self, summary_type, value):
"""Test invalid shape of every summary operators."""
net = SummaryNet(summary_type, tag='tag', data=value)

@ -122,7 +122,6 @@ class SummaryDemo(nn.Cell):
self.s("y1", y)
return z
def test_tensor_summary_with_ge():
""" test_tensor_summary_with_ge """
log.debug("begin test_tensor_summary_with_ge")

Loading…
Cancel
Save