|
|
|
@ -14,9 +14,6 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/build_strategy.h"
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph.h"
|
|
|
|
@ -71,46 +68,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
|
|
|
|
|
AppendPass("multi_devices_check_pass");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> Build(
|
|
|
|
|
const ProgramDesc &main_program,
|
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
|
const std::unordered_set<std::string> ¶m_names,
|
|
|
|
|
const std::vector<Scope *> &local_scopes,
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
|
|
|
|
|
#else
|
|
|
|
|
const bool use_cuda) const {
|
|
|
|
|
#endif
|
|
|
|
|
// Convert the program to graph.
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
|
|
|
|
|
|
|
|
|
|
for (std::shared_ptr<ir::Pass> &pass : AllPasses()) {
|
|
|
|
|
if (pass->Type() == "multi_devices_pass") {
|
|
|
|
|
pass->SetNotOwned<const std::vector<platform::Place>>("places",
|
|
|
|
|
&places);
|
|
|
|
|
pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
|
|
|
|
|
pass->SetNotOwned<const std::unordered_set<std::string>>("params",
|
|
|
|
|
¶m_names);
|
|
|
|
|
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
|
|
|
|
|
&local_scopes);
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
|
|
|
|
|
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
graph = pass->Apply(std::move(graph));
|
|
|
|
|
}
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
BuildStrategy strategy_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
ir::PassBuilder *BuildStrategy::CreatePassBuilder() const {
|
|
|
|
|
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy()
|
|
|
|
|
const {
|
|
|
|
|
pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
|
|
|
|
|
return pass_builder_.get();
|
|
|
|
|
return pass_builder_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
|
|
|
|
@ -123,20 +88,33 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
|
|
|
|
|
#else
|
|
|
|
|
const bool use_cuda) const {
|
|
|
|
|
#endif
|
|
|
|
|
// Create a default one if not intialized by user.
|
|
|
|
|
if (!pass_builder_) {
|
|
|
|
|
CreatePassBuilder();
|
|
|
|
|
CreatePassesFromStrategy();
|
|
|
|
|
}
|
|
|
|
|
// std::unique_ptr<ir::Graph> graph;
|
|
|
|
|
ParallelExecutorPassBuilder *builder =
|
|
|
|
|
reinterpret_cast<ParallelExecutorPassBuilder *>(pass_builder_.get());
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
|
|
|
|
|
|
|
|
|
|
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
|
|
|
|
|
if (pass->Type() == "multi_devices_pass") {
|
|
|
|
|
pass->Erase("places");
|
|
|
|
|
pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
|
|
|
|
|
pass->Erase("loss_var_name");
|
|
|
|
|
pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
|
|
|
|
|
pass->Erase("params");
|
|
|
|
|
pass->SetNotOwned<const std::unordered_set<std::string>>("params",
|
|
|
|
|
¶m_names);
|
|
|
|
|
pass->Erase("local_scopes");
|
|
|
|
|
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
|
|
|
|
|
&local_scopes);
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
std::unique_ptr<ir::Graph> graph =
|
|
|
|
|
builder->Build(main_program, places, loss_var_name, param_names,
|
|
|
|
|
local_scopes, use_cuda, nccl_ctxs);
|
|
|
|
|
#else
|
|
|
|
|
std::unique_ptr<ir::Graph> graph = builder->Build(
|
|
|
|
|
main_program, places, loss_var_name, param_names, local_scopes, use_cuda);
|
|
|
|
|
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
|
|
|
|
|
pass->Erase("nccl_ctxs");
|
|
|
|
|
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
graph = pass->Apply(std::move(graph));
|
|
|
|
|
}
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
|
} // namespace details
|
|
|
|
|