diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc index f602a6acd8..99e792216f 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc @@ -67,13 +67,13 @@ bool SetIOIputSize(const std::shared_ptr &anf_node, const size_t &input MS_EXCEPTION_IF_NULL(type_ptr); int64_t size_i = 1; for (size_t j = 0; j < shape_i.size(); j++) { - LongMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); + size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); input_size_list->push_back(LongToSize(size_i)); } } @@ -99,13 +99,13 @@ bool SetIOSize(const std::shared_ptr &anf_node, const std::shared_ptr(shape_i[j]), &size_i); + size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); output_size_list.push_back(LongToSize(size_i)); } kernel_mod_ptr->SetOutputSizeList(output_size_list); diff --git a/mindspore/ccsrc/operator/prim_structures.cc b/mindspore/ccsrc/operator/prim_structures.cc index ba924f5ca4..3d0cba5e83 100644 --- a/mindspore/ccsrc/operator/prim_structures.cc +++ b/mindspore/ccsrc/operator/prim_structures.cc @@ -587,7 +587,7 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr int result = 1; for (size_t i = 0; i < shpx_data.size(); i++) { int value = GetValue(shpx_data[i]); - IntMulWithOverflowCheck(result, value, &result); + result = IntMulWithOverflowCheck(result, value); } auto result_v = MakeValue(result); diff --git a/mindspore/ccsrc/utils/convert_utils_base.h b/mindspore/ccsrc/utils/convert_utils_base.h index 8960d6628b..b9a38f997f 100644 --- a/mindspore/ccsrc/utils/convert_utils_base.h +++ b/mindspore/ccsrc/utils/convert_utils_base.h @@ -91,26 +91,26 @@ inline unsigned int UlongToUint(size_t u) { return static_cast(u); } -inline void IntMulWithOverflowCheck(int a, int b, int *c) { +inline int IntMulWithOverflowCheck(int a, int b) { int out = a * b; if (a != 0) { - bool ok = ((out / a) != b); - if (ok) { + bool overflow = ((out / a) != b); + if (overflow) { MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; } } - *c = out; + return out; } -inline void LongMulWithOverflowCheck(int64_t a, int64_t b, int64_t *c) { +inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) { int64_t out = a * b; if (a != 0) { - bool ok = ((out / a) != b); - if (ok) { + bool overflow = ((out / a) != b); + if (overflow) { MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; } } - *c = out; + return out; } inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {