|
|
|
@ -20,19 +20,16 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
OpGraphView::OpGraphView(
|
|
|
|
|
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
|
|
|
|
|
Build(ops);
|
|
|
|
|
}
|
|
|
|
|
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
|
|
|
|
|
|
|
|
|
|
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
|
|
|
|
|
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
|
|
|
|
|
for (auto &op : ops) {
|
|
|
|
|
preceding_ops_[op.get()];
|
|
|
|
|
pending_ops_[op.get()];
|
|
|
|
|
preceding_ops_[op];
|
|
|
|
|
pending_ops_[op];
|
|
|
|
|
for (auto &var : op->Outputs()) {
|
|
|
|
|
for (auto &pending_op : var->PendingOps()) {
|
|
|
|
|
preceding_ops_[pending_op].insert(op.get());
|
|
|
|
|
pending_ops_[op.get()].insert(pending_op);
|
|
|
|
|
preceding_ops_[pending_op].insert(op);
|
|
|
|
|
pending_ops_[op].insert(pending_op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
|
|
|
|
|
"There are duplicate ops in graph.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
|
|
|
|
|
|
|
|
|
|
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
|
|
|
|
|
std::unordered_set<OpHandleBase *> ret;
|
|
|
|
|
for (auto &pair : preceding_ops_) {
|
|
|
|
@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
|
|
|
|
|
op == nullptr ? "nullptr" : op->DebugString());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
|
|
|
|
|
OpHandleBase *op) const {
|
|
|
|
|
EnforceHasOp(op);
|
|
|
|
|
return preceding_ops_.at(op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
|
|
|
|
|
OpHandleBase *op) const {
|
|
|
|
|
EnforceHasOp(op);
|
|
|
|
|