fea/infer memory optim2 (#14953)
parent
6597ccb01f
commit
885c4e57ab
@ -1,11 +1,18 @@
|
||||
cc_library(ir_graph_build_pass SRCS ir_graph_build_pass.cc DEPS analysis_pass argument ir_pass_manager)
|
||||
cc_library(ir_analysis_pass SRCS ir_analysis_pass.cc DEPS analysis_pass argument ir_pass_manager)
|
||||
cc_library(memory_optim_pass SRCS memory_optimize_pass.cc DEPS analysis_pass)
|
||||
cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager)
|
||||
cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass ir_analysis_pass ir_params_sync_among_devices_pass)
|
||||
cc_library(ir_graph_to_program_pass SRCS ir_graph_to_program_pass.cc DEPS analysis_pass graph_to_program_pass)
|
||||
|
||||
cc_library(analysis_passes SRCS passes.cc DEPS
|
||||
ir_graph_build_pass
|
||||
ir_analysis_pass
|
||||
ir_params_sync_among_devices_pass
|
||||
memory_optim_pass
|
||||
ir_graph_to_program_pass
|
||||
)
|
||||
|
||||
set(analysis_deps ${analysis_deps}
|
||||
ir_graph_build_pass
|
||||
ir_analysis_pass
|
||||
analysis_passes
|
||||
subgraph_detector
|
||||
CACHE INTERNAL "")
|
||||
|
||||
@ -1,62 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
|
||||
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
|
||||
#include "paddle/fluid/string/pretty_log.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
void IrAnalysisComposePass::RunImpl(Argument *argument) {
|
||||
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
|
||||
ApplyIrPasses(argument);
|
||||
CollectFusionStatis(argument);
|
||||
}
|
||||
|
||||
std::string IrAnalysisComposePass::repr() const {
|
||||
return "ir-analysis-compose-pass";
|
||||
}
|
||||
|
||||
void IrAnalysisComposePass::ApplyIrPasses(Argument *argument) {
|
||||
std::vector<std::string> passes({
|
||||
"ir_graph_build_pass", "ir_analysis_pass",
|
||||
"ir_params_sync_among_devices_pass",
|
||||
});
|
||||
for (const auto &pass : passes) {
|
||||
VLOG(2) << "Run pass " << pass;
|
||||
auto *the_pass = PassRegistry::Global().Retreive(pass);
|
||||
the_pass->Run(argument);
|
||||
}
|
||||
}
|
||||
|
||||
void IrAnalysisComposePass::CollectFusionStatis(Argument *argument) {
|
||||
if (!argument->main_graph().Has(framework::ir::kFuseStatisAttr)) {
|
||||
LOG(INFO) << "argument has no fuse statis";
|
||||
return;
|
||||
}
|
||||
argument->SetFusionStatis(
|
||||
argument->main_graph().Get<Argument::fusion_statis_t>(
|
||||
framework::ir::kFuseStatisAttr));
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
@ -0,0 +1,45 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
|
||||
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
void IrGraphToProgramPass::RunImpl(Argument *argument) {
|
||||
auto pass =
|
||||
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
|
||||
|
||||
if (argument->memory_optim_sort_kind_valid()) {
|
||||
pass->Set(framework::ir::kGraphToProgramSortKind,
|
||||
new int(argument->memory_optim_sort_kind()));
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> graph(argument->main_graph_ptr());
|
||||
framework::ProgramDesc desc(argument->main_program());
|
||||
pass->SetNotOwned("program", &desc);
|
||||
auto thegraph = pass->Apply(std::move(graph));
|
||||
thegraph.release(); // the argument still own the graph.
|
||||
|
||||
argument->SetIrAnalyzedProgram(
|
||||
new framework::proto::ProgramDesc(*desc.Proto()));
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue