Add backward and optimizer operator dependency pass. (#17746)
parent 4cb7d32c9b
commit fbbdc9ccad
@@ -0,0 +1,223 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
namespace ir {

class BackWardOpDepsPass : public ir::Pass {
 protected:
  // Add a control-dependency edge so that r runs only after l has finished.
  void AddDep(ir::Graph* graph, details::OpHandleBase* l,
              details::OpHandleBase* r) const {
    auto* dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
    graph->Get<details::GraphDepVars>(details::kGraphDepVars).emplace(dep_var);
    l->AddOutput(dep_var);
    r->AddInput(dep_var);
    VLOG(10) << "add deps:" << l->DebugString() << " and " << r->DebugString();
  }

  void ApplyImpl(ir::Graph* graph) const override {
    // NOTE: The operator nodes should be in topology order.
    std::vector<details::OpHandleBase*> backward_op_handles;
    std::vector<details::OpHandleBase*> all_opt_handles;
    details::ParamsAndGrads params_grads;
    std::vector<ir::Node*> topo_nodes = ir::TopologySortOperations(*graph);
    for (auto& node : topo_nodes) {
      if (!node->Op()) continue;

      GetBackWardOpHandles(node, &backward_op_handles, &params_grads);
      GetOptimizerOpHandles(node, &all_opt_handles);
    }

    VLOG(10) << "backward_op_handles size:" << backward_op_handles.size()
             << ", opt_handles size:" << all_opt_handles.size();

    if (backward_op_handles.size() <= 1 || all_opt_handles.size() <= 1) {
      VLOG(10) << "need not backward_op_deps_pass";
      return;
    }

    std::vector<details::OpHandleBase*> opt_handles;
    GetOptimizerHandlesRoot(all_opt_handles, &opt_handles, params_grads);

    if (opt_handles.size() <= 1) {
      VLOG(10) << "need not backward_op_deps_pass";
      return;
    }

    // Chain the root optimizer op handles in gradient-generation order.
    VLOG(10) << "add optimize deps";
    for (size_t i = 1; i < opt_handles.size(); ++i) {
      AddDep(graph, opt_handles[i - 1], opt_handles[i]);
    }

    // Run the optimizer ops only after the last backward op finishes.
    VLOG(10) << "add deps between backward and optimize:";
    AddDep(graph, backward_op_handles[backward_op_handles.size() - 1],
           opt_handles[0]);
  }

  /*
   * When the backward op handles complete, the optimizer op handles' input
   * vars are ready. Since the optimizer op handles form subgraphs that do not
   * connect to each other, they can run in parallel or in a specified order,
   * such as the order in which the gradients were generated. This function
   * gets the roots of these subgraphs.
   */
  void GetOptimizerHandlesRoot(
      const std::vector<details::OpHandleBase*>& ops,
      std::vector<details::OpHandleBase*>* result,
      const details::ParamsAndGrads& params_grads) const {
    std::unordered_set<details::OpHandleBase*> visit;
    for (auto op : ops) {
      if (visit.find(op) != visit.end()) {
        continue;
      }

      VLOG(10) << "visiting all_opt_handles:" << op->DebugString();

      result->emplace_back(op);
      visit.insert(op);
      VisitChildrens(op, &visit);
    }

    for (size_t i = 0; i < result->size(); i++) {
      VLOG(10) << "get potential head op:" << (*result)[i]->DebugString();
    }

    // Sort by param_grad order.
    std::unordered_map<std::string, int> pg_order;
    int order = 0;
    for (auto& p_g : params_grads) {
      pg_order[p_g.second] = order++;
    }

    std::vector<std::pair<details::OpHandleBase*, int>> op_handles;
    for (auto op : *result) {
      int order = 0;
      for (auto input : op->Inputs()) {
        if (dynamic_cast<details::VarHandle*>(input) == nullptr) continue;

        if (pg_order.find(input->Name()) == pg_order.end()) {
          VLOG(10) << "not find input " << input->Name() << " in grad";
          continue;
        }

        if (order < pg_order.at(input->Name())) {
          order = pg_order.at(input->Name());
        }
      }
      op_handles.emplace_back(std::make_pair(op, order));
    }

    std::sort(op_handles.begin(), op_handles.end(),
              [](const std::pair<details::OpHandleBase*, int>& left,
                 const std::pair<details::OpHandleBase*, int>& right) -> bool {
                return left.second < right.second;
              });

    result->clear();
    for (auto p : op_handles) {
      result->emplace_back(p.first);
    }

    for (size_t i = 0; i < result->size(); i++) {
      VLOG(10) << "get head op:" << (*result)[i]->DebugString();
    }
  }

  void VisitChildrens(details::OpHandleBase* op,
                      std::unordered_set<details::OpHandleBase*>* visit) const {
    for (auto out : op->Outputs()) {
      for (auto* pending_op : out->PendingOps()) {
        if (visit->find(pending_op) != visit->end()) {
          continue;
        }

        VLOG(10) << "visiting:" << pending_op->DebugString();

        visit->insert(pending_op);
        VisitChildrens(pending_op, visit);
      }
    }
  }

  void GetBackWardOpHandles(
      ir::Node* node, std::vector<details::OpHandleBase*>* backward_op_handles,
      details::ParamsAndGrads* params_grads) const {
    try {
      bool is_bk_op =
          static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
                            static_cast<int>(OpRole::kBackward));
      if (!is_bk_op) return;

      // Currently, we assume that once gradient is generated, it can be
      // broadcast, and each gradient is only broadcast once.
      auto backward_vars =
          boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
              OpProtoAndCheckerMaker::OpRoleVarAttrName()));
      PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));
      PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());

      backward_op_handles->emplace_back(
          &node->Wrapper<details::OpHandleBase>());

      for (size_t i = 0; i < backward_vars.size(); i += 2) {
        VLOG(10) << "Trainable parameter: " << backward_vars[i]
                 << ", gradient: " << backward_vars[i + 1];

        params_grads->emplace_back(std::make_pair(
            backward_vars[i] /*param*/, backward_vars[i + 1] /*grad*/));
      }
    } catch (const boost::bad_get&) {
      // The op has no role/var attribute; skip it.
    }
  }

  void GetOptimizerOpHandles(
      ir::Node* node, std::vector<details::OpHandleBase*>* opt_handles) const {
    try {
      bool is_opt_op =
          static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
                            static_cast<int>(OpRole::kOptimize));
      if (!is_opt_op) return;

      opt_handles->emplace_back(&node->Wrapper<details::OpHandleBase>());
    } catch (const boost::bad_get&) {
      // The op has no role attribute; skip it.
    }
  }
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(backward_optimizer_op_deps_pass,
              paddle::framework::ir::BackWardOpDepsPass);
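
A minimal usage sketch of the registered pass, assuming PassRegistry::Instance().Get() and Pass::Apply() carry the signatures declared in paddle/fluid/framework/ir/pass.h for this Paddle version, and that the graph already holds the kGraphDepVars attribute and OpHandleBase wrappers built by the multi-devices passes (the helper function name below is illustrative, not part of this commit):

// Hypothetical usage sketch; only the pass name matches the registration above.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

void AddBackwardOptimizerDeps(paddle::framework::ir::Graph* graph) {
  // Look up the pass registered as "backward_optimizer_op_deps_pass".
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "backward_optimizer_op_deps_pass");
  // The pass inserts control-dependency dummy vars so that the root optimizer
  // op handles run in gradient-generation order, after the last backward op.
  pass->Apply(graph);
}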