increase the max size of tensor.

pull/2964/head
jzg 5 years ago
parent b4e3715897
commit 94d0d45ab4

@ -65,16 +65,16 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input
} else { } else {
auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr); MS_EXCEPTION_IF_NULL(type_ptr);
int size_i = 1; int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) { for (size_t j = 0; j < shape_i.size(); j++) {
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i); LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
} }
size_t type_byte = GetTypeByte(type_ptr); size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) { if (type_byte == 0) {
return false; return false;
} }
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
input_size_list->push_back(IntToSize(size_i)); input_size_list->push_back(LongToSize(size_i));
} }
} }
return true; return true;
@ -97,16 +97,16 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr); MS_EXCEPTION_IF_NULL(type_ptr);
int size_i = 1; int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) { for (size_t j = 0; j < shape_i.size(); j++) {
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i); LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
} }
size_t type_byte = GetTypeByte(type_ptr); size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) { if (type_byte == 0) {
return false; return false;
} }
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
output_size_list.push_back(IntToSize(size_i)); output_size_list.push_back(LongToSize(size_i));
} }
kernel_mod_ptr->SetOutputSizeList(output_size_list); kernel_mod_ptr->SetOutputSizeList(output_size_list);
return true; return true;

@ -102,6 +102,17 @@ inline void IntMulWithOverflowCheck(int a, int b, int *c) {
*c = out; *c = out;
} }
inline void LongMulWithOverflowCheck(int64_t a, int64_t b, int64_t *c) {
int64_t out = a * b;
if (a != 0) {
bool ok = ((out / a) != b);
if (ok) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
*c = out;
}
inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) { inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {
size_t out = a * b; size_t out = a * b;
if (a != 0) { if (a != 0) {

Loading…
Cancel
Save