|
|
|
@ -306,20 +306,20 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
|
|
|
|
MS_ASSERT(nullptr != old_graph);
|
|
|
|
|
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
|
|
|
|
|
MS_ASSERT(nullptr != main_graph);
|
|
|
|
|
if (config == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "config should be specified";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto status = RunPrecedingPass(old_graph, *config);
|
|
|
|
|
auto status = RunPrecedingPass(main_graph, *config);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Run Preceding pass failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
status = RunAdjustPass(old_graph, config);
|
|
|
|
|
status = RunAdjustPass(main_graph, config);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Run Adjust pass failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
@ -357,59 +357,17 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto new_graph = optimizer->Optimize(old_graph);
|
|
|
|
|
auto new_graph = optimizer->Optimize(main_graph);
|
|
|
|
|
if (new_graph == nullptr) {
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
status = DoQuantize(old_graph, config, new_graph);
|
|
|
|
|
status = DoQuantize(main_graph, config, new_graph);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Do Quantize failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return new_graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AnfTransform::GetAllFuncGraph(const FuncGraphPtr &func_graph) {
|
|
|
|
|
if (func_graphs_.find(func_graph) == func_graphs_.end()) {
|
|
|
|
|
func_graphs_.insert(func_graph);
|
|
|
|
|
} else {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto nodes = func_graph->nodes();
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (IsValueNode<FuncGraph>(node)) {
|
|
|
|
|
auto new_fg = (node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
|
|
|
|
|
GetAllFuncGraph(new_fg);
|
|
|
|
|
}
|
|
|
|
|
if (utils::isa<CNodePtr>(node)) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
for (auto &input : cnode->inputs()) {
|
|
|
|
|
if (input->isa<ValueNode>()) {
|
|
|
|
|
if (IsValueNode<FuncGraph>(input)) {
|
|
|
|
|
auto new_fg = (input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
|
|
|
|
|
GetAllFuncGraph(new_fg);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
|
|
|
|
|
GetAllFuncGraph(main_graph);
|
|
|
|
|
|
|
|
|
|
for (auto &fg : func_graphs_) {
|
|
|
|
|
auto new_main_graph = TransformSingleFuncGraph(fg, config);
|
|
|
|
|
if (new_main_graph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "TransformSingleFuncGraph failed.";
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return main_graph;
|
|
|
|
|
}
|
|
|
|
|
} // namespace mindspore::lite
|
|
|
|
|