!4000 improve interface '__bool__' for tensor

Merge pull request !4000 from zhangbuxue/improve_bool_for_tensor
pull/4000/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 8dec74908a

@ -17,7 +17,6 @@
"""standard_method""" """standard_method"""
from dataclasses import dataclass from dataclasses import dataclass
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
from ...ops import functional as F from ...ops import functional as F
from ...ops import operations as P from ...ops import operations as P
from ...ops.primitive import constexpr from ...ops.primitive import constexpr
@ -206,13 +205,11 @@ def const_tensor_to_bool(x):
if x is None: if x is None:
raise ValueError("Only constant tensor bool can be converted to bool") raise ValueError("Only constant tensor bool can be converted to bool")
x = x.asnumpy() x = x.asnumpy()
if x.shape not in ((), (1,)):
raise ValueError("The truth value of an array with several elements is ambiguous.")
if x.shape == (): if x.shape == ():
value = bool(x) return bool(x)
else: if x.shape == (1,):
value = bool(x[0]) return bool(x[0])
return value raise ValueError("The truth value of an array with several elements is ambiguous.")
def tensor_bool(x): def tensor_bool(x):
@ -349,6 +346,3 @@ def list_append(self_, item):
def to_array(x): def to_array(x):
"""Implementation of `to_array`.""" """Implementation of `to_array`."""
return x.__ms_to_array__() return x.__ms_to_array__()
tensor_operator_registry.register('__bool__', tensor_bool)

@ -73,7 +73,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
namespace { namespace {
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
MS_LOG(DEBUG) << "Converting python tuple"; MS_LOG(DEBUG) << "Converting python tuple";
py::tuple tuple = obj.cast<py::tuple>(); auto tuple = obj.cast<py::tuple>();
std::vector<ValuePtr> value_list; std::vector<ValuePtr> value_list;
for (size_t it = 0; it < tuple.size(); ++it) { for (size_t it = 0; it < tuple.size(); ++it) {
ValuePtr out = nullptr; ValuePtr out = nullptr;
@ -91,7 +91,7 @@ bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signatur
bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
MS_LOG(DEBUG) << "Converting python list"; MS_LOG(DEBUG) << "Converting python list";
py::list list = obj.cast<py::list>(); auto list = obj.cast<py::list>();
std::vector<ValuePtr> value_list; std::vector<ValuePtr> value_list;
for (size_t it = 0; it < list.size(); ++it) { for (size_t it = 0; it < list.size(); ++it) {
ValuePtr out = nullptr; ValuePtr out = nullptr;
@ -124,7 +124,7 @@ bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signa
bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
MS_LOG(DEBUG) << "Converting python dict"; MS_LOG(DEBUG) << "Converting python dict";
py::dict dict_values = obj.cast<py::dict>(); auto dict_values = obj.cast<py::dict>();
std::vector<std::pair<std::string, ValuePtr>> key_values; std::vector<std::pair<std::string, ValuePtr>> key_values;
for (auto item : dict_values) { for (auto item : dict_values) {
if (!py::isinstance<py::str>(item.first)) { if (!py::isinstance<py::str>(item.first)) {
@ -208,7 +208,7 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
bool ConvertSlice(const py::object &obj, ValuePtr *const data) { bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting slice object"; MS_LOG(DEBUG) << "Converting slice object";
py::slice slice_obj = obj.cast<py::slice>(); auto slice_obj = obj.cast<py::slice>();
auto convert_func = [obj](std::string attr) -> ValuePtr { auto convert_func = [obj](std::string attr) -> ValuePtr {
auto py_attr = py::getattr(obj, attr.c_str()); auto py_attr = py::getattr(obj, attr.c_str());
if (py::isinstance<py::none>(py_attr)) { if (py::isinstance<py::none>(py_attr)) {
@ -335,7 +335,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::isinstance<MetaTensor>(obj)) { } else if (py::isinstance<MetaTensor>(obj)) {
converted = obj.cast<MetaTensorPtr>(); converted = obj.cast<MetaTensorPtr>();
} else if (py::isinstance<EnvInstance>(obj)) { } else if (py::isinstance<EnvInstance>(obj)) {
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>(); auto env = obj.cast<std::shared_ptr<EnvInstance>>();
converted = env; converted = env;
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) { } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj); converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
@ -374,7 +374,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
data_converter::MakeProperNameToFuncGraph(func_graph, obj_id); data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
data_converter::CacheObjectValue(obj_id, func_graph); data_converter::CacheObjectValue(obj_id, func_graph);
if (obj_key != "") { if (!obj_key.empty()) {
MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString(); MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
data_converter::SetObjGraphValue(obj_key, func_graph); data_converter::SetObjGraphValue(obj_key, func_graph);
} }
@ -440,7 +440,7 @@ bool IsCellInstance(const py::object &obj) {
py::object CreatePythonObject(const py::object &type, const py::tuple &params) { py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object obj; py::object obj;
if (params.size() == 0) { if (params.empty()) {
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type); obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
} else { } else {
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
@ -499,7 +499,7 @@ ClassPtr ParseDataClass(const py::object &cls_obj) {
ClassAttrVector attributes; ClassAttrVector attributes;
py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj); py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
for (auto &item : names) { for (auto &item : names) {
TypePtr type_value = item.second.cast<TypePtr>(); auto type_value = item.second.cast<TypePtr>();
MS_EXCEPTION_IF_NULL(type_value); MS_EXCEPTION_IF_NULL(type_value);
MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")"; MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value)); attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
@ -508,8 +508,8 @@ ClassPtr ParseDataClass(const py::object &cls_obj) {
std::unordered_map<std::string, ValuePtr> methods_map; std::unordered_map<std::string, ValuePtr> methods_map;
py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj); py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
for (auto &item : methods) { for (auto &item : methods) {
std::string fun_name = item.first.cast<std::string>(); auto fun_name = item.first.cast<std::string>();
py::object obj = py::cast<py::object>(item.second); auto obj = py::cast<py::object>(item.second);
std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name); std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
methods_map[fun_name] = method_obj; methods_map[fun_name] = method_obj;
} }

@ -108,8 +108,12 @@ class Tensor(Tensor_):
return out return out
def __bool__(self): def __bool__(self):
out = tensor_operator_registry.get('__bool__')(self) data = self.asnumpy()
return out if data.shape == ():
return bool(data)
if data.shape == (1,):
return bool(data[0])
raise ValueError("The truth value of an array with several elements is ambiguous.")
def __pos__(self): def __pos__(self):
return self return self

@ -35,7 +35,6 @@ def test_dtype_and_shape_as_attr():
dtype = x.dtype dtype = x.dtype
return shape, dtype return shape, dtype
net = Net() net = Net()
x = Tensor(np.ones([1, 2, 3], np.int32)) x = Tensor(np.ones([1, 2, 3], np.int32))
ret = net(x) ret = net(x)
@ -55,7 +54,6 @@ def test_dtype_and_shape_as_attr_to_new_tensor():
y = self.fill(dtype, shape, self.value) y = self.fill(dtype, shape, self.value)
return y return y
net = Net(2.2) net = Net(2.2)
x = Tensor(np.ones([1, 2, 3], np.float32)) x = Tensor(np.ones([1, 2, 3], np.float32))
ret = net(x) ret = net(x)
@ -71,7 +69,6 @@ def test_type_not_have_the_attr():
shape = x.shapes shape = x.shapes
return shape return shape
net = Net() net = Net()
x = Tensor(np.ones([1, 2, 3], np.int32)) x = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as ex: with pytest.raises(RuntimeError) as ex:
@ -88,7 +85,6 @@ def test_type_not_have_the_method():
shape = x.dtypes() shape = x.dtypes()
return shape return shape
net = Net() net = Net()
x = Tensor(np.ones([1, 2, 3], np.int32)) x = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as ex: with pytest.raises(RuntimeError) as ex:

Loading…
Cancel
Save