commit
869487a2b7
@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
|
||||
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||
#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
||||
#include "paddle/fluid/framework/details/op_graph_view.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
static bool IsLockAndRecordEventFreeComputationOpHandle(
|
||||
ComputationOpHandle *op, const OpGraphView &graph_view) {
|
||||
if (!platform::is_gpu_place(op->GetPlace())) return false;
|
||||
for (auto &pending_op : graph_view.PendingOps(op)) {
|
||||
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
|
||||
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> ir_graph) const {
|
||||
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
|
||||
OpGraphView graph_view(all_ops);
|
||||
for (auto &op : all_ops) {
|
||||
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
|
||||
if (compute_op == nullptr) continue;
|
||||
bool is_lock_and_record_event_free =
|
||||
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
|
||||
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
|
||||
if (is_lock_and_record_event_free) {
|
||||
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
|
||||
<< compute_op->DebugString();
|
||||
}
|
||||
}
|
||||
return ir_graph;
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(modify_op_lock_and_record_event_pass,
|
||||
paddle::framework::details::ModifyOpLockAndRecordEventPass);
|
@ -0,0 +1,32 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class ModifyOpLockAndRecordEventPass : public ir::Pass {
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const override;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,77 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/op_graph_view.h"
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
OpGraphView::OpGraphView(
|
||||
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
|
||||
Build(ops);
|
||||
}
|
||||
|
||||
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
|
||||
for (auto &op : ops) {
|
||||
preceding_ops_[op.get()];
|
||||
pending_ops_[op.get()];
|
||||
for (auto &var : op->Outputs()) {
|
||||
for (auto &pending_op : var->PendingOps()) {
|
||||
preceding_ops_[pending_op].insert(op.get());
|
||||
pending_ops_[op.get()].insert(pending_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
PADDLE_ENFORCE(
|
||||
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
|
||||
"There are duplicate ops in graph.");
|
||||
}
|
||||
|
||||
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
|
||||
|
||||
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
|
||||
std::unordered_set<OpHandleBase *> ret;
|
||||
for (auto &pair : preceding_ops_) {
|
||||
ret.insert(pair.first);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool OpGraphView::HasOp(OpHandleBase *op) const {
|
||||
return preceding_ops_.count(op) != 0;
|
||||
}
|
||||
|
||||
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
|
||||
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
|
||||
op == nullptr ? "nullptr" : op->DebugString());
|
||||
}
|
||||
|
||||
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
|
||||
OpHandleBase *op) const {
|
||||
EnforceHasOp(op);
|
||||
return preceding_ops_.at(op);
|
||||
}
|
||||
|
||||
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
|
||||
OpHandleBase *op) const {
|
||||
EnforceHasOp(op);
|
||||
return pending_ops_.at(op);
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,54 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class OpGraphView {
|
||||
public:
|
||||
explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
|
||||
|
||||
size_t OpNumber() const;
|
||||
|
||||
std::unordered_set<OpHandleBase *> AllOps() const;
|
||||
|
||||
const std::unordered_set<OpHandleBase *> &PrecedingOps(
|
||||
OpHandleBase *op) const;
|
||||
|
||||
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
|
||||
|
||||
bool HasOp(OpHandleBase *op) const;
|
||||
|
||||
private:
|
||||
void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
|
||||
void EnforceHasOp(OpHandleBase *op) const;
|
||||
|
||||
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
|
||||
preceding_ops_;
|
||||
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
|
||||
pending_ops_;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_proto_maker.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
|
||||
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
|
||||
op1->Outputs() == op2->Outputs();
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
// FIXME(zjl): Insert dependencies between some distributed ops may cause
|
||||
// the multi_devices_graph_pass fails. So we skip these ops here.
|
||||
// Indeed, maybe we should not insert dependencies between these ops
|
||||
// casually, which may cause deadlock easily.
|
||||
// We should add more skipped distributed ops when found errors in
|
||||
// multi_devices_graph_pass
|
||||
static std::unordered_set<std::string> skip_dist_ops{
|
||||
"send", "recv", "send_barrier", "fetch_barrier"};
|
||||
|
||||
auto &ops = Get<const std::vector<OpDesc *>>(kAllOpDescs);
|
||||
std::vector<ir::Node *> op_node_list;
|
||||
op_node_list.reserve(ops.size());
|
||||
|
||||
std::unordered_map<ir::Node *, size_t> op_deps;
|
||||
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
|
||||
std::unordered_set<ir::Node *> ready_ops;
|
||||
|
||||
for (ir::Node *node : graph->Nodes()) {
|
||||
if (!node->IsOp()) continue;
|
||||
std::unordered_set<ir::Node *> preceding_ops;
|
||||
for (auto *in : node->inputs) {
|
||||
PADDLE_ENFORCE(in->IsVar(),
|
||||
"Preceding Node of Op Nodes must be Var Node");
|
||||
if (in->inputs.empty()) continue;
|
||||
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
|
||||
"Preceding Op Node of Var Node must be unique");
|
||||
preceding_ops.insert(in->inputs[0]);
|
||||
pending_ops[in->inputs[0]].insert(node);
|
||||
}
|
||||
op_deps[node] = preceding_ops.size();
|
||||
if (preceding_ops.empty()) {
|
||||
ready_ops.insert(node);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto *op_desc : ops) {
|
||||
ir::Node *found_node = nullptr;
|
||||
for (auto *node : ready_ops) {
|
||||
if (IsSameOpDesc(op_desc, node->Op())) {
|
||||
PADDLE_ENFORCE(found_node == nullptr,
|
||||
"Found multiple op_desc in graph: %s", op_desc->Type());
|
||||
found_node = node;
|
||||
}
|
||||
}
|
||||
|
||||
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
|
||||
op_desc->Type());
|
||||
for (auto *pending_op : pending_ops[found_node]) {
|
||||
if (--op_deps.at(pending_op) == 0) {
|
||||
ready_ops.insert(pending_op);
|
||||
}
|
||||
}
|
||||
ready_ops.erase(found_node);
|
||||
if (skip_dist_ops.count(op_desc->Type()) == 0) {
|
||||
op_node_list.push_back(found_node);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < op_node_list.size(); ++i) {
|
||||
auto *dep_var = graph->CreateControlDepVar();
|
||||
op_node_list[i]->inputs.push_back(dep_var);
|
||||
op_node_list[i - 1]->outputs.push_back(dep_var);
|
||||
dep_var->outputs.push_back(op_node_list[i]);
|
||||
dep_var->inputs.push_back(op_node_list[i - 1]);
|
||||
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
|
||||
<< " and " << op_node_list[i]->Name();
|
||||
}
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(sequential_execution_pass,
|
||||
paddle::framework::details::SequentialExecutionPass)
|
||||
.RequirePassAttr(paddle::framework::details::kAllOpDescs);
|
@ -0,0 +1,34 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
constexpr char kAllOpDescs[] = "all_op_descs";
|
||||
|
||||
class SequentialExecutionPass : public ir::Pass {
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const override;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue