parent
8837669782
commit
5be6f762d0
@ -0,0 +1,62 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
|
||||
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||
#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
||||
#include "paddle/fluid/framework/details/op_handle_graph.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
static ComputationOpHandle *ConvertToComputationOpHandle(OpHandleBase *op) {
|
||||
return dynamic_cast<ComputationOpHandle *>(op);
|
||||
}
|
||||
|
||||
static bool IsLockAndRecordEventFreeComputationOpHandle(
|
||||
ComputationOpHandle *op, const OpHandleGraph &graph) {
|
||||
for (auto &pending_op : graph.PendingOps(op)) {
|
||||
auto *tmp = ConvertToComputationOpHandle(pending_op);
|
||||
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> ir_graph) const {
|
||||
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
|
||||
OpHandleGraph graph(all_ops);
|
||||
for (auto &op : all_ops) {
|
||||
auto *compute_op = ConvertToComputationOpHandle(op.get());
|
||||
if (compute_op == nullptr) continue;
|
||||
bool is_lock_and_record_event_free =
|
||||
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph);
|
||||
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
|
||||
if (is_lock_and_record_event_free) {
|
||||
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
|
||||
<< compute_op->DebugString();
|
||||
}
|
||||
}
|
||||
return ir_graph;
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(modify_op_lock_and_record_event_pass,
|
||||
paddle::framework::details::ModifyOpLockAndRecordEventPass);
|
@ -0,0 +1,32 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class ModifyOpLockAndRecordEventPass : public ir::Pass {
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const override;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,87 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class OpHandleGraph {
|
||||
public:
|
||||
enum Relation { kSame = 0, kBefore = 1, kAfter = 2, kNoDeps = 3 };
|
||||
|
||||
explicit OpHandleGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
|
||||
|
||||
size_t OpNumber() const;
|
||||
|
||||
std::unordered_set<OpHandleBase *> AllOps() const;
|
||||
|
||||
const std::unordered_set<OpHandleBase *> &PrecedingOps(
|
||||
OpHandleBase *op) const;
|
||||
|
||||
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
|
||||
|
||||
std::vector<std::unordered_set<OpHandleBase *>> AllPrecedingOps(
|
||||
OpHandleBase *op) const;
|
||||
|
||||
std::vector<std::unordered_set<OpHandleBase *>> AllPendingOps(
|
||||
OpHandleBase *op) const;
|
||||
|
||||
bool HasOp(OpHandleBase *op) const;
|
||||
|
||||
Relation RelationBetween(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||
|
||||
bool IsSame(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||
|
||||
bool IsBeforeOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||
|
||||
bool IsBefore(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||
|
||||
bool IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||
|
||||
bool IsAfter(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||
|
||||
bool IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||
|
||||
OpHandleBase *NearestCommonParent(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||
|
||||
// Find an operator that is after op and before op1, op2
|
||||
OpHandleBase *NearestCommonParentAfter(OpHandleBase *op, OpHandleBase *op1,
|
||||
OpHandleBase *op2) const;
|
||||
|
||||
std::unordered_set<OpHandleBase *> NoPendingOpSet() const;
|
||||
|
||||
std::unordered_set<OpHandleBase *> NoPrecedingOpSet() const;
|
||||
|
||||
private:
|
||||
void BuildGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
|
||||
void EnforceHasOp(OpHandleBase *op) const;
|
||||
bool IsBeforeOrSameImpl(OpHandleBase *op1, OpHandleBase *op2) const;
|
||||
|
||||
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
|
||||
preceding_ops_;
|
||||
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
|
||||
pending_ops_;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue