parent
8837669782
commit
5be6f762d0
@ -0,0 +1,62 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
|
||||||
|
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||||
|
#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
||||||
|
#include "paddle/fluid/framework/details/op_handle_graph.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace details {
|
||||||
|
|
||||||
|
static ComputationOpHandle *ConvertToComputationOpHandle(OpHandleBase *op) {
|
||||||
|
return dynamic_cast<ComputationOpHandle *>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool IsLockAndRecordEventFreeComputationOpHandle(
|
||||||
|
ComputationOpHandle *op, const OpHandleGraph &graph) {
|
||||||
|
for (auto &pending_op : graph.PendingOps(op)) {
|
||||||
|
auto *tmp = ConvertToComputationOpHandle(pending_op);
|
||||||
|
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
|
||||||
|
std::unique_ptr<ir::Graph> ir_graph) const {
|
||||||
|
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
|
||||||
|
OpHandleGraph graph(all_ops);
|
||||||
|
for (auto &op : all_ops) {
|
||||||
|
auto *compute_op = ConvertToComputationOpHandle(op.get());
|
||||||
|
if (compute_op == nullptr) continue;
|
||||||
|
bool is_lock_and_record_event_free =
|
||||||
|
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph);
|
||||||
|
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
|
||||||
|
if (is_lock_and_record_event_free) {
|
||||||
|
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
|
||||||
|
<< compute_op->DebugString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ir_graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace details
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
REGISTER_PASS(modify_op_lock_and_record_event_pass,
|
||||||
|
paddle::framework::details::ModifyOpLockAndRecordEventPass);
|
||||||
@ -0,0 +1,32 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/ir/graph.h"
|
||||||
|
#include "paddle/fluid/framework/ir/pass.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace details {
|
||||||
|
|
||||||
|
class ModifyOpLockAndRecordEventPass : public ir::Pass {
|
||||||
|
protected:
|
||||||
|
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||||
|
std::unique_ptr<ir::Graph> graph) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace details
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,87 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace details {
|
||||||
|
|
||||||
|
class OpHandleGraph {
|
||||||
|
public:
|
||||||
|
enum Relation { kSame = 0, kBefore = 1, kAfter = 2, kNoDeps = 3 };
|
||||||
|
|
||||||
|
explicit OpHandleGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
|
||||||
|
|
||||||
|
size_t OpNumber() const;
|
||||||
|
|
||||||
|
std::unordered_set<OpHandleBase *> AllOps() const;
|
||||||
|
|
||||||
|
const std::unordered_set<OpHandleBase *> &PrecedingOps(
|
||||||
|
OpHandleBase *op) const;
|
||||||
|
|
||||||
|
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
|
||||||
|
|
||||||
|
std::vector<std::unordered_set<OpHandleBase *>> AllPrecedingOps(
|
||||||
|
OpHandleBase *op) const;
|
||||||
|
|
||||||
|
std::vector<std::unordered_set<OpHandleBase *>> AllPendingOps(
|
||||||
|
OpHandleBase *op) const;
|
||||||
|
|
||||||
|
bool HasOp(OpHandleBase *op) const;
|
||||||
|
|
||||||
|
Relation RelationBetween(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
bool IsSame(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
bool IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
bool IsBefore(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
bool IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
bool IsAfter(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
bool IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
OpHandleBase *NearestCommonParent(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
// Find an operator that is after op and before op1, op2
|
||||||
|
OpHandleBase *NearestCommonParentAfter(OpHandleBase *op, OpHandleBase *op1,
|
||||||
|
OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
std::unordered_set<OpHandleBase *> NoPendingOpSet() const;
|
||||||
|
|
||||||
|
std::unordered_set<OpHandleBase *> NoPrecedingOpSet() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void BuildGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
|
||||||
|
void EnforceHasOp(OpHandleBase *op) const;
|
||||||
|
bool IsBeforeOrSameImpl(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||||
|
|
||||||
|
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
|
||||||
|
preceding_ops_;
|
||||||
|
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
|
||||||
|
pending_ops_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace details
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
Loading…
Reference in new issue