From 94d0d45ab44781c96b612a44763c2afbf0cc4a5b Mon Sep 17 00:00:00 2001 From: jzg Date: Thu, 9 Jul 2020 15:05:30 +0800 Subject: [PATCH] increase the max size of tensor. --- .../ccsrc/kernel/aicpu/aicpu_kernel_build.cc | 16 ++++++++-------- mindspore/ccsrc/utils/convert_utils_base.h | 11 +++++++++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc index c83994b5f2..f602a6acd8 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc @@ -65,16 +65,16 @@ bool SetIOIputSize(const std::shared_ptr &anf_node, const size_t &input } else { auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); MS_EXCEPTION_IF_NULL(type_ptr); - int size_i = 1; + int64_t size_i = 1; for (size_t j = 0; j < shape_i.size(); j++) { - IntMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); + LongMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); - input_size_list->push_back(IntToSize(size_i)); + LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + input_size_list->push_back(LongToSize(size_i)); } } return true; @@ -97,16 +97,16 @@ bool SetIOSize(const std::shared_ptr &anf_node, const std::shared_ptr shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); MS_EXCEPTION_IF_NULL(type_ptr); - int size_i = 1; + int64_t size_i = 1; for (size_t j = 0; j < shape_i.size(); j++) { - IntMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); + LongMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); - output_size_list.push_back(IntToSize(size_i)); + LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + output_size_list.push_back(LongToSize(size_i)); } kernel_mod_ptr->SetOutputSizeList(output_size_list); return true; diff --git a/mindspore/ccsrc/utils/convert_utils_base.h b/mindspore/ccsrc/utils/convert_utils_base.h index 3638a43e6a..8960d6628b 100644 --- a/mindspore/ccsrc/utils/convert_utils_base.h +++ b/mindspore/ccsrc/utils/convert_utils_base.h @@ -102,6 +102,17 @@ inline void IntMulWithOverflowCheck(int a, int b, int *c) { *c = out; } +inline void LongMulWithOverflowCheck(int64_t a, int64_t b, int64_t *c) { + int64_t out = a * b; + if (a != 0) { + bool ok = ((out / a) != b); + if (ok) { + MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; + } + } + *c = out; +} + inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) { size_t out = a * b; if (a != 0) {