!6400 add overflow check for make_range and optimize isinstance processing

Merge pull request !6400 from zhangbuxue/add_overflow_check_for_make_range_and_optimize_isinstance_processing
pull/6400/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 8346da267b

@ -173,7 +173,7 @@ def check_type_same(x_type, base_type):
"""Check x_type is same as base_type.""" """Check x_type is same as base_type."""
if mstype.issubclass_(x_type, base_type): if mstype.issubclass_(x_type, base_type):
return True return True
raise TypeError(f"The arg 'x' should be a {base_type}, but got {x_type}.") return False
@constexpr @constexpr

@ -489,15 +489,25 @@ AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr
if (slide.step <= 0) { if (slide.step <= 0) {
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
} }
for (int i = slide.start; i < slide.stop; i += slide.step) { for (int i = slide.start; i < slide.stop; i += slide.step) {
args.push_back(abstract::FromValue(i)); args.push_back(abstract::FromValue(i));
if (i > 0 && INT_MAX - i < slide.step) {
MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, "
"will cause integer overflow.";
}
} }
} else { } else {
if (slide.step >= 0) { if (slide.step >= 0) {
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
} }
for (int i = slide.start; i > slide.stop; i += slide.step) { for (int i = slide.start; i > slide.stop; i += slide.step) {
args.push_back(abstract::FromValue(i)); args.push_back(abstract::FromValue(i));
if (i < 0 && INT_MIN - i > slide.step) {
MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, "
"will cause integer overflow.";
}
} }
} }

@ -268,7 +268,7 @@ def _tensor_index_by_tuple_slice(data, t):
def tensor_index_by_tuple(data, tuple_index): def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types""" """Tensor getitem by tuple of various types"""
if len(tuple_index) == 1: if len(tuple_index) == 1:
return data[tuple_index[0]] return data[tuple_index[0]]
indexes_types = hyper_map(F.typeof, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR: if index_elements_type == const_utils.NO_TENSOR:

@ -40,17 +40,17 @@ def test_number_not_in_tuple():
if self.number_in not in self.tuple_: if self.number_in not in self.tuple_:
ret += 1 ret += 1
if self.number_not_in not in self.tuple_: if self.number_not_in not in self.tuple_:
ret += 1 ret += 2
if self.number_in not in self.list_: if self.number_in not in self.list_:
ret += 3 ret += 3
if self.number_not_in not in self.list_: if self.number_not_in not in self.list_:
ret += 3 ret += 4
if self.str_in not in self.dict_: if self.str_in not in self.dict_:
ret += 5 ret += 5
if self.str_not_in not in self.dict_: if self.str_not_in not in self.dict_:
ret += 5 ret += 6
return ret return ret
net = Net() net = Net()
output = net() output = net()
assert output == 9 assert output == 12

Loading…
Cancel
Save