|
|
|
@ -28,6 +28,7 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
using details::ComputationOpHandle;
|
|
|
|
|
using details::DummyVarHandle;
|
|
|
|
|
using details::FetchOpHandle;
|
|
|
|
|
using details::NCCLAllReduceOpHandle;
|
|
|
|
@ -35,7 +36,6 @@ using details::OpHandleBase;
|
|
|
|
|
using details::ScaleLossGradOpHandle;
|
|
|
|
|
using details::VarHandle;
|
|
|
|
|
using details::VarHandleBase;
|
|
|
|
|
using details::ComputationOpHandle;
|
|
|
|
|
|
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
|
public:
|
|
|
|
@ -43,7 +43,9 @@ class ParallelExecutorPrivate {
|
|
|
|
|
const std::vector<platform::Place> &places)
|
|
|
|
|
: places_(places),
|
|
|
|
|
fetch_dev_ctxs_(places),
|
|
|
|
|
pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
|
|
|
|
|
pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {
|
|
|
|
|
vars_.resize(places.size());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<platform::Place> places_;
|
|
|
|
|
platform::DeviceContextPool fetch_dev_ctxs_;
|
|
|
|
@ -52,12 +54,7 @@ class ParallelExecutorPrivate {
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
|
|
|
|
|
|
|
|
|
|
platform::Place main_place_;
|
|
|
|
|
|
|
|
|
|
std::unordered_map<platform::Place,
|
|
|
|
|
std::unordered_map<std::string, std::map<int, VarHandle>>,
|
|
|
|
|
platform::PlaceHash>
|
|
|
|
|
vars_;
|
|
|
|
|
std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
|
|
|
|
|
|
|
|
|
@ -69,8 +66,8 @@ class ParallelExecutorPrivate {
|
|
|
|
|
std::unique_ptr<platform::EnforceNotMet> exception_;
|
|
|
|
|
|
|
|
|
|
VarHandle *GetVarHandle(const std::string &each_var_name,
|
|
|
|
|
const platform::Place &place) {
|
|
|
|
|
auto &var_holders = vars_[place];
|
|
|
|
|
const platform::Place &place, size_t place_offset) {
|
|
|
|
|
auto &var_holders = vars_[place_offset];
|
|
|
|
|
auto &var_holder = var_holders[each_var_name];
|
|
|
|
|
VarHandle *var = nullptr;
|
|
|
|
|
if (var_holder.empty()) {
|
|
|
|
@ -118,8 +115,8 @@ class ParallelExecutorPrivate {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GenerateVar(OpHandleBase *op_handle, const std::string &each_var_name,
|
|
|
|
|
const platform::Place &place) {
|
|
|
|
|
auto &vars = vars_[place][each_var_name];
|
|
|
|
|
const platform::Place &place, size_t place_offset) {
|
|
|
|
|
auto &vars = vars_[place_offset][each_var_name];
|
|
|
|
|
size_t version = vars.size();
|
|
|
|
|
auto &var = vars[version];
|
|
|
|
|
var.version_ = version;
|
|
|
|
@ -144,11 +141,10 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
|
for (size_t i = 0; i < member_->places_.size(); ++i) {
|
|
|
|
|
member_->local_scopes_.push_back(&scope->NewScope());
|
|
|
|
|
}
|
|
|
|
|
member_->main_place_ = places[0];
|
|
|
|
|
|
|
|
|
|
// Bcast Parameters to all GPUs
|
|
|
|
|
BuildNCCLCommunicator();
|
|
|
|
|
if (platform::is_gpu_place(member_->main_place_) &&
|
|
|
|
|
if (platform::is_gpu_place(places[0]) &&
|
|
|
|
|
member_->local_scopes_.size() != 1) { // Is CUDA
|
|
|
|
|
BCastParamsToGPUs(startup_program);
|
|
|
|
|
}
|
|
|
|
@ -201,13 +197,13 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
auto var_names = op->InputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
VarHandle *var = member_->GetVarHandle(each_var_name, p);
|
|
|
|
|
VarHandle *var = member_->GetVarHandle(each_var_name, p, i);
|
|
|
|
|
op_handle->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
var_names = op->OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
member_->GenerateVar(op_handle, each_var_name, p);
|
|
|
|
|
member_->GenerateVar(op_handle, each_var_name, p, i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (is_forwarding) {
|
|
|
|
@ -224,7 +220,7 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
// loss->pending_ops_.emplace_back(op_handle);
|
|
|
|
|
// op_handle->inputs_.emplace_back(loss);
|
|
|
|
|
|
|
|
|
|
member_->GenerateVar(op_handle, loss_var_name + "@GRAD", p);
|
|
|
|
|
member_->GenerateVar(op_handle, loss_var_name + "@GRAD", p, i);
|
|
|
|
|
change_forward = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -245,7 +241,7 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < member_->places_.size(); ++i) {
|
|
|
|
|
auto &p = member_->places_[i];
|
|
|
|
|
auto &vars = member_->vars_[p][og];
|
|
|
|
|
auto &vars = member_->vars_[i][og];
|
|
|
|
|
|
|
|
|
|
if (vars.empty()) { // This device has no data. continue.
|
|
|
|
|
continue;
|
|
|
|
@ -280,8 +276,8 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
|
|
|
|
|
*/
|
|
|
|
|
void ParallelExecutor::PolishGraphToSupportDataHazards() const {
|
|
|
|
|
for (auto &place_pair : member_->vars_) {
|
|
|
|
|
for (auto &name_pair : place_pair.second) {
|
|
|
|
|
for (auto &var_map : member_->vars_) {
|
|
|
|
|
for (auto &name_pair : var_map) {
|
|
|
|
|
if (name_pair.second.size() <= 1) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -369,8 +365,8 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
std::unordered_map<OpHandleBase *, size_t> pending_ops;
|
|
|
|
|
std::vector<DummyVarHandle> dummy_vars;
|
|
|
|
|
|
|
|
|
|
for (auto &place_pair : member_->vars_) {
|
|
|
|
|
for (auto &name_pair : place_pair.second) {
|
|
|
|
|
for (auto &var_map : member_->vars_) {
|
|
|
|
|
for (auto &name_pair : var_map) {
|
|
|
|
|
for (auto &version_pair : name_pair.second) {
|
|
|
|
|
pending_vars[&version_pair.second] =
|
|
|
|
|
version_pair.second.generated_op_ == nullptr;
|
|
|
|
@ -395,9 +391,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
|
|
|
|
|
|
|
|
|
|
for (auto &fetch_var_name : fetch_tensors) {
|
|
|
|
|
for (auto &pair : member_->vars_) {
|
|
|
|
|
auto it = pair.second.find(fetch_var_name);
|
|
|
|
|
if (it != pair.second.end()) {
|
|
|
|
|
for (auto &var_map : member_->vars_) {
|
|
|
|
|
auto it = var_map.find(fetch_var_name);
|
|
|
|
|
if (it != var_map.end()) {
|
|
|
|
|
fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|