|
|
|
@ -14,15 +14,17 @@
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/analysis_pass.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/flags.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/pass.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace inference {
|
|
|
|
|
namespace analysis {
|
|
|
|
|
using namespace framework;
|
|
|
|
|
|
|
|
|
|
static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
|
|
|
|
|
|
|
|
|
@ -48,7 +50,8 @@ class FluidToIrPass final : public DataFlowGraphPass {
|
|
|
|
|
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
|
|
|
|
|
// Load program.
|
|
|
|
|
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
|
|
|
|
|
argument->origin_program_desc.reset(new proto::ProgramDesc(program));
|
|
|
|
|
argument->origin_program_desc.reset(
|
|
|
|
|
new framework::proto::ProgramDesc(program));
|
|
|
|
|
// Create main data flow graph.
|
|
|
|
|
if (!argument->main_dfg) {
|
|
|
|
|
argument->main_dfg.reset(new DataFlowGraph);
|
|
|
|
@ -78,12 +81,13 @@ class FluidToIrPass final : public DataFlowGraphPass {
|
|
|
|
|
IRPassManager ir_passes(argument_->Get<ProgramDesc>("ir_program_desc"),
|
|
|
|
|
nullptr);
|
|
|
|
|
// Pass the scope from analysis to IR if needed.
|
|
|
|
|
if (argument_->Has(ir::kParamScopeAttr)) {
|
|
|
|
|
if (argument_->Has(framework::ir::kParamScopeAttr)) {
|
|
|
|
|
// Here the address is passed, attention that IR doesn't own the scope, so
|
|
|
|
|
// the real scope in analysis should live during the IR phase.
|
|
|
|
|
ir_passes.graph().Set(
|
|
|
|
|
ir::kParamScopeAttr,
|
|
|
|
|
new Scope *(&argument_->Get<Scope>(ir::kParamScopeAttr)));
|
|
|
|
|
framework::ir::kParamScopeAttr,
|
|
|
|
|
new framework::Scope *(&argument_->Get<framework::Scope>(
|
|
|
|
|
framework::ir::kParamScopeAttr)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_IA_enable_ir) {
|
|
|
|
@ -95,12 +99,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
|
|
|
|
|
PADDLE_ENFORCE(argument_->main_dfg.get());
|
|
|
|
|
argument_->main_dfg->Build(ir_passes.graph());
|
|
|
|
|
// inherit the arguments from ir.
|
|
|
|
|
if (ir_passes.graph().Has(ir::kFuseStatisAttr)) {
|
|
|
|
|
if (ir_passes.graph().Has(framework::ir::kFuseStatisAttr)) {
|
|
|
|
|
argument_->Set(
|
|
|
|
|
ir::kFuseStatisAttr,
|
|
|
|
|
framework::ir::kFuseStatisAttr,
|
|
|
|
|
new std::unordered_map<std::string, int>(
|
|
|
|
|
ir_passes.graph().Get<std::unordered_map<std::string, int>>(
|
|
|
|
|
ir::kFuseStatisAttr)));
|
|
|
|
|
framework::ir::kFuseStatisAttr)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -112,7 +116,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// Load parameters from a single file or from a directory.
|
|
|
|
|
bool LoadParams(Scope *scope, const std::string &dir,
|
|
|
|
|
bool LoadParams(framework::Scope *scope, const std::string &dir,
|
|
|
|
|
const std::string &prog_file, const std::string ¶m_file);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|