From 94d0d45ab44781c96b612a44763c2afbf0cc4a5b Mon Sep 17 00:00:00 2001 From: jzg Date: Thu, 9 Jul 2020 15:05:30 +0800 Subject: [PATCH 1/2] increase the max size of tensor. --- .../ccsrc/kernel/aicpu/aicpu_kernel_build.cc | 16 ++++++++-------- mindspore/ccsrc/utils/convert_utils_base.h | 11 +++++++++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc index c83994b5f2..f602a6acd8 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc @@ -65,16 +65,16 @@ bool SetIOIputSize(const std::shared_ptr &anf_node, const size_t &input } else { auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); MS_EXCEPTION_IF_NULL(type_ptr); - int size_i = 1; + int64_t size_i = 1; for (size_t j = 0; j < shape_i.size(); j++) { - IntMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); + LongMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); - input_size_list->push_back(IntToSize(size_i)); + LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + input_size_list->push_back(LongToSize(size_i)); } } return true; @@ -97,16 +97,16 @@ bool SetIOSize(const std::shared_ptr &anf_node, const std::shared_ptr shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); MS_EXCEPTION_IF_NULL(type_ptr); - int size_i = 1; + int64_t size_i = 1; for (size_t j = 0; j < shape_i.size(); j++) { - IntMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); + LongMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); - output_size_list.push_back(IntToSize(size_i)); + LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + output_size_list.push_back(LongToSize(size_i)); } kernel_mod_ptr->SetOutputSizeList(output_size_list); return true; diff --git a/mindspore/ccsrc/utils/convert_utils_base.h b/mindspore/ccsrc/utils/convert_utils_base.h index 3638a43e6a..8960d6628b 100644 --- a/mindspore/ccsrc/utils/convert_utils_base.h +++ b/mindspore/ccsrc/utils/convert_utils_base.h @@ -102,6 +102,17 @@ inline void IntMulWithOverflowCheck(int a, int b, int *c) { *c = out; } +inline void LongMulWithOverflowCheck(int64_t a, int64_t b, int64_t *c) { + int64_t out = a * b; + if (a != 0) { + bool ok = ((out / a) != b); + if (ok) { + MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; + } + } + *c = out; +} + inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) { size_t out = a * b; if (a != 0) { From fb90ff164bf2e6174844850ce12158721ea328a2 Mon Sep 17 00:00:00 2001 From: jzg Date: Thu, 9 Jul 2020 15:05:30 +0800 Subject: [PATCH 2/2] increase the max size of tensor. --- .../ccsrc/kernel/aicpu/aicpu_kernel_build.cc | 8 ++++---- mindspore/ccsrc/operator/prim_structures.cc | 2 +- mindspore/ccsrc/utils/convert_utils_base.h | 16 ++++++++-------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc index f602a6acd8..99e792216f 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc @@ -67,13 +67,13 @@ bool SetIOIputSize(const std::shared_ptr &anf_node, const size_t &input MS_EXCEPTION_IF_NULL(type_ptr); int64_t size_i = 1; for (size_t j = 0; j < shape_i.size(); j++) { - LongMulWithOverflowCheck(size_i, static_cast(shape_i[j]), &size_i); + size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); input_size_list->push_back(LongToSize(size_i)); } } @@ -99,13 +99,13 @@ bool SetIOSize(const std::shared_ptr &anf_node, const std::shared_ptr(shape_i[j]), &size_i); + size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); } size_t type_byte = GetTypeByte(type_ptr); if (type_byte == 0) { return false; } - LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); + size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); output_size_list.push_back(LongToSize(size_i)); } kernel_mod_ptr->SetOutputSizeList(output_size_list); diff --git a/mindspore/ccsrc/operator/prim_structures.cc b/mindspore/ccsrc/operator/prim_structures.cc index ba924f5ca4..3d0cba5e83 100644 --- a/mindspore/ccsrc/operator/prim_structures.cc +++ b/mindspore/ccsrc/operator/prim_structures.cc @@ -587,7 +587,7 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr int result = 1; for (size_t i = 0; i < shpx_data.size(); i++) { int value = GetValue(shpx_data[i]); - IntMulWithOverflowCheck(result, value, &result); + result = IntMulWithOverflowCheck(result, value); } auto result_v = MakeValue(result); diff --git a/mindspore/ccsrc/utils/convert_utils_base.h b/mindspore/ccsrc/utils/convert_utils_base.h index 8960d6628b..b9a38f997f 100644 --- a/mindspore/ccsrc/utils/convert_utils_base.h +++ b/mindspore/ccsrc/utils/convert_utils_base.h @@ -91,26 +91,26 @@ inline unsigned int UlongToUint(size_t u) { return static_cast(u); } -inline void IntMulWithOverflowCheck(int a, int b, int *c) { +inline int IntMulWithOverflowCheck(int a, int b) { int out = a * b; if (a != 0) { - bool ok = ((out / a) != b); - if (ok) { + bool overflow = ((out / a) != b); + if (overflow) { MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; } } - *c = out; + return out; } -inline void LongMulWithOverflowCheck(int64_t a, int64_t b, int64_t *c) { +inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) { int64_t out = a * b; if (a != 0) { - bool ok = ((out / a) != b); - if (ok) { + bool overflow = ((out / a) != b); + if (overflow) { MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; } } - *c = out; + return out; } inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {