|
|
|
@ -142,6 +142,14 @@ void IRPassManager::CreatePasses(Argument *argument,
|
|
|
|
|
disable_logs_ = argument->disable_logs();
|
|
|
|
|
if (pass_name == "fc_fuse_pass") {
|
|
|
|
|
pass->Set("use_gpu", new bool(argument->use_gpu()));
|
|
|
|
|
bool fc_mkldnn_pass = 0;
|
|
|
|
|
for (const std::string &pass_n : passes) {
|
|
|
|
|
if (pass_n == "fc_mkldnn_pass") {
|
|
|
|
|
fc_mkldnn_pass = 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
|
|
|
|
|
pass->Set("use_fc_padding", new bool(use_fc_padding));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pre_pass = pass_name;
|
|
|
|
@ -150,47 +158,12 @@ void IRPassManager::CreatePasses(Argument *argument,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IRPassManager::HasPass(const std::string &pass_type) {
|
|
|
|
|
if (passes_.empty()) return false;
|
|
|
|
|
auto it = std::find_if(
|
|
|
|
|
passes_.begin(), passes_.end(),
|
|
|
|
|
[&](std::unique_ptr<Pass> &pass) { return pass->Type() == pass_type; });
|
|
|
|
|
return it != passes_.end();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Pass> &IRPassManager::GetPass(const std::string &pass_type) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(passes_.empty(), false,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The list of passes cannot be empty."));
|
|
|
|
|
auto it = std::find_if(passes_.begin(), passes_.end(),
|
|
|
|
|
[&](const std::unique_ptr<Pass> &pass) {
|
|
|
|
|
return pass->Type() == pass_type;
|
|
|
|
|
});
|
|
|
|
|
PADDLE_ENFORCE_NE(it, passes_.end(),
|
|
|
|
|
platform::errors::PermissionDenied(
|
|
|
|
|
"You cannot get pass which was not added earlier."));
|
|
|
|
|
return *it;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Some passes depend on each other. This method serves for exchanging
|
|
|
|
|
// information between them.
|
|
|
|
|
void IRPassManager::UpdatePasses() {
|
|
|
|
|
// Update padding settings for fc_fuse_pass. Skipp adding padding for
|
|
|
|
|
// MKL-DNN-based FC
|
|
|
|
|
bool use_fc_padding = !HasPass("fc_mkldnn_pass");
|
|
|
|
|
if (HasPass("fc_fuse_pass")) {
|
|
|
|
|
auto &fc_fuse_pass = GetPass("fc_fuse_pass");
|
|
|
|
|
fc_fuse_pass->Set<bool>("use_fc_padding", new bool(use_fc_padding));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
|
|
|
|
|
if (passes_.empty()) {
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(graph.get(), platform::errors::PreconditionNotMet(
|
|
|
|
|
"Graph cannot be NULL."));
|
|
|
|
|
UpdatePasses();
|
|
|
|
|
// Apply all the passes
|
|
|
|
|
for (const auto &pass : passes_) {
|
|
|
|
|
if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
|
|
|
|
|