|
|
|
@ -204,15 +204,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
|
|
|
|
|
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
|
|
|
|
|
std::unique_ptr<ir::Graph> graph,
|
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
|
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
|
|
|
|
|
const size_t &nranks,
|
|
|
|
|
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
|
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
|
const std::vector<Scope *> &local_scopes,
|
|
|
|
|
const size_t &nranks,
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
|
|
|
|
|
const bool use_cuda,
|
|
|
|
|
platform::NCCLContextMap *nccl_ctxs) const {
|
|
|
|
|
#else
|
|
|
|
|
const bool use_cuda) const {
|
|
|
|
|
const bool use_cuda) const {
|
|
|
|
|
#endif
|
|
|
|
|
// Create a default one if not finalized by user.
|
|
|
|
|
CreatePassesFromStrategy(false);
|
|
|
|
@ -265,7 +266,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "Start Apply Pass " << pass->Type();
|
|
|
|
|
graph = pass->Apply(std::move(graph));
|
|
|
|
|
graph = pass->Apply(graph);
|
|
|
|
|
VLOG(3) << "Finish Apply Pass " << pass->Type();
|
|
|
|
|
}
|
|
|
|
|
return graph;
|
|
|
|
|