!8507 [MSLITE] Check validity of input shapes before running graphs.

From: @wang_shaocong
Reviewed-by: @zhang_xue_tong
Signed-off-by:
pull/8507/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 73f11b3ca6

@@ -24,7 +24,7 @@ typedef struct SplitParameter {
OpParameter op_parameter_;
SplitQuantArg quant_arg_;
int num_split_;
int split_sizes_[32];
int *split_sizes_;
int strides_[32];
int split_dim_;
int n_dims_;

@@ -29,6 +29,13 @@ int Executor::CheckInputs(std::vector<Tensor *> &in_tensors) {
MS_LOG(ERROR) << "Graph input tensor data is nullptr";
return RET_ERROR;
}
auto shape = inTensor->shape();
bool valid = all_of(shape.begin(), shape.end(), [](int i) { return i > 0; });
if (!valid) {
MS_LOG(ERROR) << "The shape of input tensor contains zero or negative dimension,"
<< "check the model and assign the input shape with method Resize().";
return RET_ERROR;
}
}
return RET_OK;
}
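
The check added above rejects any graph input whose shape still contains a zero or negative dimension and points the user at Resize(). A minimal standalone sketch of the same idea, using a plain std::vector<int> instead of lite::Tensor (names here are illustrative, not the lite API):

#include <algorithm>
#include <iostream>
#include <vector>

// Minimal sketch of the shape check added in Executor::CheckInputs:
// every input dimension must be strictly positive before the graph runs.
bool ShapeIsValid(const std::vector<int> &shape) {
  return std::all_of(shape.begin(), shape.end(), [](int dim) { return dim > 0; });
}

int main() {
  std::vector<int> static_shape = {1, 3, 224, 224};    // fully specified
  std::vector<int> dynamic_shape = {-1, 3, 224, 224};  // batch dim still unresolved
  std::cout << ShapeIsValid(static_shape) << "\n";     // 1
  std::cout << ShapeIsValid(dynamic_shape) << "\n";    // 0 -> call Resize() first
  return 0;
}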

@@ -32,13 +32,19 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive
auto param = reinterpret_cast<mindspore::lite::Split *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
split_param->op_parameter_.type_ = primitive->Type();
split_param->num_split_ = param->GetNumberSplit();
int *split_sizes = reinterpret_cast<int *>(malloc(split_param->num_split_ * sizeof(int)));
if (split_sizes == nullptr) {
MS_LOG(ERROR) << "malloc split size of SplitParameter failed.";
return nullptr;
}
memset(split_sizes, 0, split_param->num_split_ * sizeof(int));
split_param->split_sizes_ = split_sizes;
auto split_sizes_vector_ = param->GetSizeSplits();
int i = 0;
for (auto iter = split_sizes_vector_.begin(); iter != split_sizes_vector_.end(); iter++) {
split_param->split_sizes_[i++] = *iter;
}
split_param->split_dim_ = param->GetSplitDim();
split_param->num_split_ = param->GetNumberSplit();
return reinterpret_cast<OpParameter *>(split_param);
}
Registry SplitParameterRegistry(schema::PrimitiveType_Split, PopulateSplitParameter);
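
With split_sizes_ changed from a fixed int[32] to a heap pointer sized by num_split_, the populate code has to allocate, zero, and eventually release that buffer. A minimal sketch of the same allocate/copy/free pattern under a stand-in struct (SplitLikeParam and the helper names are hypothetical, not the real SplitParameter API):

#include <cstdlib>
#include <cstring>
#include <vector>

// Stand-in for SplitParameter; only the two fields touched by this change.
struct SplitLikeParam {
  int num_split_ = 0;
  int *split_sizes_ = nullptr;
};

// Allocate split_sizes_ to exactly num_split_ ints, zero it, then copy the
// per-output sizes in, mirroring the new PopulateSplitParameter flow.
bool FillSplitSizes(SplitLikeParam *param, const std::vector<int> &sizes) {
  param->num_split_ = static_cast<int>(sizes.size());
  param->split_sizes_ = static_cast<int *>(malloc(param->num_split_ * sizeof(int)));
  if (param->split_sizes_ == nullptr) {
    return false;  // allocation failed, caller must abort population
  }
  memset(param->split_sizes_, 0, param->num_split_ * sizeof(int));
  for (int i = 0; i < param->num_split_; ++i) {
    param->split_sizes_[i] = sizes[i];
  }
  return true;
}

// The heap buffer needs a matching free when the parameter is released,
// otherwise every Split node leaks num_split_ ints.
void FreeSplitSizes(SplitLikeParam *param) {
  free(param->split_sizes_);
  param->split_sizes_ = nullptr;
}

int main() {
  SplitLikeParam param;
  if (FillSplitSizes(&param, {2, 3, 5})) {
    FreeSplitSizes(&param);
  }
  return 0;
}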

@@ -294,7 +294,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
return ret;
}
ret = ChangeOpAxis(graph, node);
if (ret != RET_OK) {
if (ret == RET_NOT_SUPPORT) {
MS_LOG(INFO) << "not support to ChangeOpAxis";
return RET_OK;
} else if (ret != RET_OK) {
MS_LOG(INFO) << "no need to ChangeOpAxis";
return ret;
}
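
The pass now distinguishes a "not supported" result, which only skips the axis rewrite for that node, from a real failure, which still aborts the pass. A small sketch of that three-way handling with stand-in status codes (the enum values below are illustrative, not lite's real error codes):

#include <iostream>

// Stand-in status codes; the real values live in lite's error code header.
enum Status { RET_OK = 0, RET_ERROR = -1, RET_NOT_SUPPORT = -2 };

// Sketch of the new control flow: a "not supported" axis change is downgraded
// to success (the rewrite is simply skipped), any other failure stays fatal.
Status HandleChangeOpAxisResult(Status ret) {
  if (ret == RET_NOT_SUPPORT) {
    std::cout << "not support to ChangeOpAxis, skip node\n";
    return RET_OK;
  }
  if (ret != RET_OK) {
    std::cout << "ChangeOpAxis failed\n";
    return ret;
  }
  return RET_OK;
}

int main() {
  std::cout << HandleChangeOpAxisResult(RET_NOT_SUPPORT) << "\n";  // 0, pass keeps going
  std::cout << HandleChangeOpAxisResult(RET_ERROR) << "\n";        // -1, pass aborts
  return 0;
}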

@@ -251,6 +251,7 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node)
quant_param->zeroPoint = static_cast<int32_t>(onnx_node_attr.i());
}
}
quant_param->inited = true;
tensor->quantParams.emplace_back(std::move(quant_param));
} else {
MS_LOG(ERROR) << "unsupported data type " << tensor->dataType;
@@ -369,6 +370,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
MS_LOG(ERROR) << "new QuantParamT failed, node: " << dst_op->name;
return;
}
quant_param->inited = true;
int argNum = 0;
for (const auto &onnx_node_attr : node.attribute()) {
if (onnx_node_attr.name() == "Y_scale") {
@@ -384,6 +386,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
quant_param->zeroPoint = 0;
quant_param->min = FLT_MAX;
quant_param->max = FLT_MAX;
quant_param->inited = false;
}
dst_tensor->quantParams.emplace_back(std::move(quant_param));
if (argNum == 2) {
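
The parser changes above mark a quant parameter as inited only when its scale and zero point were actually read from the model, and reset the flag when sentinel values are filled in instead. A minimal sketch of that flag convention using a stand-in struct rather than schema::QuantParamT (the field set and helper names are assumptions):

#include <cfloat>
#include <iostream>

// Stand-in for schema::QuantParamT, reduced to the fields touched by the diff.
struct QuantParam {
  double scale = 1.0;
  int zeroPoint = 0;
  double min = 0.0;
  double max = 0.0;
  bool inited = false;  // true only when scale/zeroPoint were actually parsed
};

// A fully parsed quant param is marked inited ...
QuantParam MakeParsedParam(double scale, int zero_point) {
  QuantParam p;
  p.scale = scale;
  p.zeroPoint = zero_point;
  p.inited = true;
  return p;
}

// ... while a placeholder filled with sentinel values is explicitly not,
// so downstream quantization can skip it.
QuantParam MakePlaceholderParam() {
  QuantParam p;
  p.scale = 1.0;
  p.zeroPoint = 0;
  p.min = FLT_MAX;
  p.max = FLT_MAX;
  p.inited = false;
  return p;
}

int main() {
  std::cout << MakeParsedParam(0.05, 128).inited << "\n";  // 1
  std::cout << MakePlaceholderParam().inited << "\n";      // 0
  return 0;
}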

@@ -39,9 +39,9 @@ STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
}
if (onnx_node.op_type() == "Int8Quantize") {
attr->srcT = kNumberTypeFloat32;
attr->dstT = kNumberTypeInt8;
attr->dstT = kNumberTypeUInt8;
} else if (onnx_node.op_type() == "Int8Dequantize") {
attr->srcT = kNumberTypeInt8;
attr->srcT = kNumberTypeUInt8;
attr->dstT = kNumberTypeFloat32;
} else {
MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
