parent
dd73d18bb7
commit
b123e43bf9
@ -0,0 +1,140 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
|
||||
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
||||
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/nccl_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
||||
const std::vector<platform::Place> &places,
|
||||
const std::string &loss_var_name,
|
||||
const std::unordered_set<std::string> ¶ms,
|
||||
const std::vector<Scope *> &local_scopes,
|
||||
platform::NCCLContextMap *nccl_ctxs)
|
||||
: loss_var_name_(loss_var_name),
|
||||
places_(places),
|
||||
local_scopes_(local_scopes),
|
||||
nccl_ctxs_(nccl_ctxs) {
|
||||
for (auto &p : params) {
|
||||
grad_names_.insert(GradVarName(p));
|
||||
}
|
||||
}
|
||||
|
||||
void MultiDevSSAGraphBuilder::Build(const ProgramDesc &program,
|
||||
SSAGraph *graph) const {
|
||||
SSAGraph &result = *graph;
|
||||
result.vars_.resize(places_.size());
|
||||
|
||||
bool is_forwarding = true;
|
||||
for (auto *op : program.Block(0).AllOps()) {
|
||||
bool change_forward = false;
|
||||
if (!is_forwarding) {
|
||||
// FIXME(yy): Do not hard code like this
|
||||
if (op->OutputArgumentNames().size() == 1 &&
|
||||
op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) {
|
||||
continue; // Drop fill 1. for backward coeff;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < places_.size(); ++i) {
|
||||
auto &p = places_[i];
|
||||
auto *s = local_scopes_[i];
|
||||
|
||||
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
|
||||
auto *op_handle = result.ops_.back().get();
|
||||
op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
|
||||
platform::DeviceContextPool::Instance().Get(p));
|
||||
|
||||
auto var_names = op->InputArgumentNames();
|
||||
|
||||
for (auto &each_var_name : var_names) {
|
||||
VarHandle *var =
|
||||
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
|
||||
op_handle->AddInput(var);
|
||||
}
|
||||
var_names = op->OutputArgumentNames();
|
||||
|
||||
for (auto &each_var_name : var_names) {
|
||||
CreateOpOutput(&result, op_handle, each_var_name, p, i);
|
||||
}
|
||||
|
||||
if (is_forwarding) {
|
||||
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
|
||||
// Insert ScaleCost OpHandle
|
||||
op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p,
|
||||
nccl_ctxs_->DevCtx(p));
|
||||
result.ops_.emplace_back(op_handle);
|
||||
|
||||
// FIXME: Currently ScaleLossGradOp only use device_count as scale
|
||||
// factor. So it does not depend on any other operators.
|
||||
// VarHandle *loss = GetVarHandle(loss_var_name, place);
|
||||
// loss->pending_ops_.emplace_back(op_handle);
|
||||
// op_handle->inputs_.emplace_back(loss);
|
||||
|
||||
CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p, i);
|
||||
change_forward = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (change_forward) {
|
||||
is_forwarding = false;
|
||||
}
|
||||
|
||||
if (!is_forwarding) {
|
||||
auto var_names = op->OutputArgumentNames();
|
||||
for (auto &og : var_names) {
|
||||
if (grad_names_.count(og) != 0) { // is param grad
|
||||
// Insert NCCL AllReduce Op
|
||||
result.ops_.emplace_back(
|
||||
new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
|
||||
auto *op_handle = result.ops_.back().get();
|
||||
|
||||
for (size_t i = 0; i < places_.size(); ++i) {
|
||||
auto &p = places_[i];
|
||||
auto &vars = result.vars_[i][og];
|
||||
|
||||
if (vars.empty()) { // This device has no data. continue.
|
||||
continue;
|
||||
}
|
||||
auto *prev_grad = &vars[vars.size() - 1];
|
||||
op_handle->AddInput(prev_grad);
|
||||
|
||||
auto &var = vars[vars.size()];
|
||||
var.place_ = p;
|
||||
var.name_ = og;
|
||||
var.version_ = vars.size() - 1;
|
||||
|
||||
op_handle->AddOutput(&var);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Dependency graph has been constructed. However, there are still data
|
||||
harzaeds need to be handled.
|
||||
*/
|
||||
PolishGraphToSupportDataHazards(&result);
|
||||
}
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,46 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
class NCCLContextMap;
|
||||
}
|
||||
|
||||
namespace framework {
|
||||
class Scope;
|
||||
namespace details {
|
||||
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
|
||||
public:
|
||||
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
|
||||
const std::string &loss_var_name,
|
||||
const std::unordered_set<std::string> ¶ms,
|
||||
const std::vector<Scope *> &local_scopes,
|
||||
platform::NCCLContextMap *nccl_ctxs);
|
||||
|
||||
void Build(const ProgramDesc &program, SSAGraph *graph) const override;
|
||||
|
||||
private:
|
||||
std::string loss_var_name_;
|
||||
const std::vector<platform::Place> &places_;
|
||||
const std::vector<Scope *> &local_scopes_;
|
||||
platform::NCCLContextMap *nccl_ctxs_;
|
||||
std::unordered_set<std::string> grad_names_;
|
||||
};
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,88 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
|
||||
for (auto &var_map : graph->vars_) {
|
||||
for (auto &name_pair : var_map) {
|
||||
if (name_pair.second.size() <= 1) {
|
||||
return;
|
||||
}
|
||||
auto it_new = name_pair.second.rbegin();
|
||||
auto it_old = name_pair.second.rbegin();
|
||||
++it_old;
|
||||
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
|
||||
auto *write_op = it_new->second.generated_op_;
|
||||
auto &read_ops = it_old->second.pending_ops_;
|
||||
auto *ex_write_op = it_old->second.generated_op_;
|
||||
|
||||
if (ex_write_op == nullptr) { // Nobody write this var.
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto *read_op : read_ops) {
|
||||
// Manually add a dependency var from read_op to write_op;
|
||||
if (read_op == write_op) {
|
||||
// Read Write is the same op.
|
||||
continue;
|
||||
}
|
||||
|
||||
auto *dep_var = new DummyVarHandle();
|
||||
read_op->AddOutput(dep_var);
|
||||
write_op->AddInput(dep_var);
|
||||
graph->dep_vars_.emplace(dep_var);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
|
||||
SSAGraph *graph, const std::string &each_var_name,
|
||||
const platform::Place &place, size_t place_offset) {
|
||||
auto &var_holders = graph->vars_[place_offset];
|
||||
auto &var_holder = var_holders[each_var_name];
|
||||
VarHandle *var = nullptr;
|
||||
if (var_holder.empty()) {
|
||||
auto &init_var = var_holder[0];
|
||||
init_var.place_ = place;
|
||||
init_var.name_ = each_var_name;
|
||||
init_var.generated_op_ = nullptr;
|
||||
init_var.version_ = 0;
|
||||
var = &init_var;
|
||||
} else {
|
||||
var = &var_holder.rbegin()->second;
|
||||
}
|
||||
return var;
|
||||
}
|
||||
|
||||
void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
|
||||
const std::string &each_var_name,
|
||||
const platform::Place &place,
|
||||
size_t place_offset) {
|
||||
auto &vars = graph->vars_[place_offset][each_var_name];
|
||||
size_t version = vars.size();
|
||||
auto &var = vars[version];
|
||||
var.version_ = version;
|
||||
var.name_ = each_var_name;
|
||||
var.place_ = place;
|
||||
op_handle->AddOutput(&var);
|
||||
}
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,56 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/ssa_graph.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class SSAGraphBuilder {
|
||||
public:
|
||||
SSAGraphBuilder() {}
|
||||
virtual ~SSAGraphBuilder() {}
|
||||
virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
|
||||
|
||||
protected:
|
||||
/**
|
||||
* We only handle write after read(WAR), since it should not have a write
|
||||
* after write in program. If there are write after write operators, we need
|
||||
* prune them.
|
||||
*
|
||||
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
|
||||
*/
|
||||
static void PolishGraphToSupportDataHazards(SSAGraph *graph);
|
||||
|
||||
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
|
||||
const std::string &each_var_name,
|
||||
const platform::Place &place,
|
||||
size_t place_offset);
|
||||
|
||||
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
|
||||
const std::string &each_var_name,
|
||||
const platform::Place &place, size_t place_offset);
|
||||
};
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue