increase the max size of tensor.

pull/2964/head
jzg 5 years ago
parent b4e3715897
commit 94d0d45ab4

@ -65,16 +65,16 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input
} else {
auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int size_i = 1;
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
input_size_list->push_back(IntToSize(size_i));
LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
input_size_list->push_back(LongToSize(size_i));
}
}
return true;
@ -97,16 +97,16 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int size_i = 1;
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
output_size_list.push_back(IntToSize(size_i));
LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
output_size_list.push_back(LongToSize(size_i));
}
kernel_mod_ptr->SetOutputSizeList(output_size_list);
return true;

@ -102,6 +102,17 @@ inline void IntMulWithOverflowCheck(int a, int b, int *c) {
*c = out;
}
inline void LongMulWithOverflowCheck(int64_t a, int64_t b, int64_t *c) {
int64_t out = a * b;
if (a != 0) {
bool ok = ((out / a) != b);
if (ok) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
*c = out;
}
inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {
size_t out = a * b;
if (a != 0) {

Loading…
Cancel
Save