@@ -19,12 +19,12 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
-    const std::vector<platform::Place> &places,
-    std::unique_ptr<ir::Graph> graph) {
+std::vector<std::unique_ptr<ir::Graph>>
+ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
+    std::unique_ptr<ir::Graph> &&graph) {
   std::vector<std::unique_ptr<ir::Graph>> graphs;
-  graphs.reserve(places.size());
-  for (size_t i = 0; i < places.size(); ++i) {
+  graphs.reserve(places_.size());
+  for (size_t i = 0; i < places_.size(); ++i) {
     ProgramDesc empty;
     graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
     auto &g = graphs.back();
@@ -60,7 +60,7 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
     }
   }
 
-  for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) {
+  for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
     auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
     auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
     for (auto &name_pair : origin_vars) {
@@ -80,14 +80,26 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
 ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::vector<std::unique_ptr<ir::Graph>> &&graphs)
+    const framework::ProgramDesc &main_prog, std::unique_ptr<ir::Graph> &&graph)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
      pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      graphs_(std::move(graphs)) {
+      main_prog_(main_prog),
+      // TODO(Yancey1989): copying graphs is not safe since it deletes the attrs.
+      graphs_(SeparateMultiDevicesGraph(std::move(graph))) {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 
+  auto seq_allreduce_pass =
+      ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
+  seq_allreduce_pass->Erase(details::kAllOpDescs);
+  seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
+      details::kAllOpDescs,
+      new std::vector<OpDesc *>(main_prog_.Block(0).AllOps()));
+  for (size_t i = 0; i < graphs_.size(); ++i) {
+    graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
+  }
+
   // set the correct size of thread pool to each device.
   strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
                                ? 1UL
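Note on the refactor: instead of receiving pre-built per-device graphs, the constructor now takes the main `ProgramDesc` plus a single multi-device graph, splits it per device via the new `SeparateMultiDevicesGraph` member, and re-applies `all_reduce_deps_pass` to each sub-graph so all-reduce ops keep a consistent order on every device. A minimal self-contained sketch of the splitting idea follows; `DeviceOp` and `SeparateByDevice` are invented for illustration and are not Paddle's API, which works on `ir::Graph` nodes and the `kGraphVars` attribute instead.

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for an op already assigned to a device. In the real code the
// device assignment lives on the ir::Graph nodes, not in a flat list.
struct DeviceOp {
  std::string name;
  size_t dev_id;
};

// Conceptual analogue of SeparateMultiDevicesGraph: walk one multi-device
// graph and bucket each op into a fresh per-device graph, so each device can
// later be driven by its own executor thread.
std::vector<std::vector<DeviceOp>> SeparateByDevice(
    const std::vector<DeviceOp> &multi_dev_graph, size_t num_places) {
  std::vector<std::vector<DeviceOp>> graphs(num_places);
  for (const auto &op : multi_dev_graph) {
    graphs[op.dev_id].push_back(op);
  }
  return graphs;
}

int main() {
  // Two "devices", four ops spread across them.
  std::vector<DeviceOp> g = {
      {"fc_0", 0}, {"fc_1", 1}, {"all_reduce_0", 0}, {"all_reduce_1", 1}};
  auto per_dev = SeparateByDevice(g, 2);
  for (size_t dev = 0; dev < per_dev.size(); ++dev) {
    std::cout << "device " << dev << ":";
    for (const auto &op : per_dev[dev]) std::cout << " " << op.name;
    std::cout << "\n";
  }
  return 0;
}
```

The per-graph loop over `graphs_` in the patch then applies the sequencing pass to each bucket independently, which is why the pass's `kAllOpDescs` attribute is rebuilt from `main_prog_.Block(0).AllOps()` first.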