|
|
|
@ -58,40 +58,72 @@ std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::ve
|
|
|
|
|
}
|
|
|
|
|
return lite_tensors;
|
|
|
|
|
}
|
|
|
|
|
void PrintTensorShape(const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors) {
|
|
|
|
|
int i = 0;
|
|
|
|
|
for (auto input_tensor : input_tensors) {
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
for (auto &dim : input_tensor->shape()) {
|
|
|
|
|
oss << " " << dim;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "input shape " << i++ << ":" << oss.str();
|
|
|
|
|
}
|
|
|
|
|
i = 0;
|
|
|
|
|
for (auto output_tensor : output_tensors) {
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
for (auto &dim : output_tensor->shape()) {
|
|
|
|
|
oss << " " << dim;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "output shape" << i++ << ":" << oss.str();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void FreeTensors(std::vector<Tensor *> input_tensors, std::vector<Tensor *> output_tensors) {
|
|
|
|
|
input_tensors.clear();
|
|
|
|
|
input_tensors.shrink_to_fit();
|
|
|
|
|
output_tensors.clear();
|
|
|
|
|
output_tensors.shrink_to_fit();
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
STATUS InferShapePass::Run(MetaGraphT *graph) {
|
|
|
|
|
MS_ASSERT(graph != nullptr);
|
|
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
|
|
|
|
auto &node = *iter;
|
|
|
|
|
auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type);
|
|
|
|
|
std::vector<Tensor *> output_tensors;
|
|
|
|
|
if (input_tensors.empty() || input_tensors.size() != node->inputIndex.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "convert input lite tensor error";
|
|
|
|
|
FreeTensors(input_tensors, output_tensors);
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
auto output_tensors = ConvertTensorToLiteTensor(graph, node->outputIndex, node->primitive->value.type);
|
|
|
|
|
output_tensors = ConvertTensorToLiteTensor(graph, node->outputIndex, node->primitive->value.type);
|
|
|
|
|
if (output_tensors.empty() || output_tensors.size() != node->outputIndex.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "convert output lite tensor error";
|
|
|
|
|
FreeTensors(input_tensors, output_tensors);
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<PrimitiveT> primitiveT(new (std::nothrow) PrimitiveT(*node->primitive));
|
|
|
|
|
std::unique_ptr<PrimitiveT> primitiveT(new(std::nothrow) PrimitiveT(*node->primitive));
|
|
|
|
|
if (primitiveT == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "copy primitiveT error";
|
|
|
|
|
FreeTensors(input_tensors, output_tensors);
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto primitiveC = std::shared_ptr<PrimitiveC>(PrimitiveC::Create(primitiveT.release()));
|
|
|
|
|
if (primitiveC == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "unpack primitiveT error";
|
|
|
|
|
FreeTensors(input_tensors, output_tensors);
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto ret = primitiveC->InferShape(input_tensors, output_tensors);
|
|
|
|
|
MS_LOG(DEBUG) << "cur node:" << node->name;
|
|
|
|
|
if (ret == RET_INFER_INVALID) {
|
|
|
|
|
MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name
|
|
|
|
|
<< ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type) << "flag set to false.";
|
|
|
|
|
} else if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(WARNING) << "InferShape failed, name: " << node->name
|
|
|
|
|
<< ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type);
|
|
|
|
|
FreeTensors(input_tensors, output_tensors);
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
PrintTensorShape(input_tensors, output_tensors);
|
|
|
|
|
// copy output shape to tensorT
|
|
|
|
|
for (size_t i = 0; i < output_tensors.size(); i++) {
|
|
|
|
|
auto output_dims = output_tensors[i]->shape();
|
|
|
|
@ -100,12 +132,7 @@ STATUS InferShapePass::Run(MetaGraphT *graph) {
|
|
|
|
|
output_tensor->format = output_tensors[i]->GetFormat();
|
|
|
|
|
output_tensor->dataType = output_tensors[i]->data_type();
|
|
|
|
|
}
|
|
|
|
|
for (auto input_tensor : input_tensors) {
|
|
|
|
|
delete input_tensor;
|
|
|
|
|
}
|
|
|
|
|
for (auto output_tensor : output_tensors) {
|
|
|
|
|
delete output_tensor;
|
|
|
|
|
}
|
|
|
|
|
FreeTensors(input_tensors, output_tensors);
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|