commit
bc95a4ccfe
@@ -1,120 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

constexpr char kAllOpDescs[] = "all_op_descs";

std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// Sort ops in BFS order.
std::vector<ir::Node*> BFSSortGraphOps(const ir::Graph& graph);

class ControlFlowGraph;

class AnalysisVarPass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

 private:
  // Fill the variable map (var_nodes_) by version.
  void InitSSAGraphNodes() const;
  // Update program descs.
  void RenameVarInGraphDesc(const std::string& var,
                            const std::string& cache_var, size_t idx) const;
  // Update IR nodes.
  void RenameVarInGraphNode(const std::string& var,
                            const std::string& cache_var, size_t idx,
                            ir::Graph* graph) const;

  void SubGraphOptimize(OpDesc* op_desc) const;
  // Validate whether a tensor can be reused or not.
  bool NodeCanReused(ir::Node* node) const;
  // Scan the sub-block and collect its output/input variables.
  std::unordered_set<std::string> GetSubBlockVars(
      const std::unordered_set<ir::Node*>&) const;
  // Check whether an op has a sub-block or not.
  bool OpHasSubBlock(OpDesc* desc) const;

 private:
  // Reuse node pool, owned.
  mutable OrderedNodePairPool pool_;
  // Control-flow graph.
  mutable std::unique_ptr<ControlFlowGraph> cfg_;
  // Skip set.
  mutable std::unordered_set<std::string> skip_set_;
  // Var nodes.
  mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
};

class ControlFlowGraph {
 public:
  ControlFlowGraph() = default;
  // For the IR graph in ParallelExecutor.
  explicit ControlFlowGraph(const ir::Graph& graph);

  void LiveVariableAnalysis();

  void RenameVarInCFGGraph(const std::string& old_node,
                           const std::string& new_node, int begin_idx);

  const std::set<std::string> LiveIn(ir::Node* op) const;
  const std::set<std::string> LiveOut(ir::Node* op) const;
  const std::set<std::string> Use(ir::Node* op) const;
  const std::vector<ir::Node*> Ops() const;
  std::vector<ir::Node*>& Ops();

  // For SSA-graph nodes.
  ir::Node* GetNodeFromVarName(const std::string& name, ir::Node* op) const;

 private:
  void BuildCFGGraph();
  void ConnectNodes();
  using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
  using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
  // Successor ops that use the output variables.
  NodeListMap successors_;
  // Predecessor ops that generated the input variables.
  NodeListMap predecessors_;
  // Variables live before running the current op.
  VarSetMap live_in_;
  // Variables live after running the current op.
  VarSetMap live_out_;
  VarSetMap uses_;  // op inputs
  VarSetMap defs_;  // op outputs

  std::vector<ir::Node*> ops_;  // op sequence by topological sort
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,80 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace framework {

// A no-op operator used as a stand-in for real kernels in tests.
class DummyOp : public OperatorBase {
 public:
  DummyOp(const std::string& type, const VariableNameMap& inputs,
          const VariableNameMap& outputs, const AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {}
};

class SumOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "").AsDuplicable();
    AddOutput("Out", "");
    AddComment("");
  }
};

class AssignOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "").AsDuplicable();
    AddOutput("Out", "");
    AddComment("");
  }
};

class SplitOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "");
    AddOutput("Out", "").AsDuplicable();
    AddComment("");
  }
};

// Propagates the input variable's type to the output variable.
class DummyVarTypeInference : public VarTypeInference {
 public:
  void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
    auto& inputs = op_desc.Input("X");
    auto type = block->Var(inputs.front())->GetType();
    auto out_var_name = op_desc.Output("Out").front();
    block->Var(out_var_name)->SetType(type);
  }
};

}  // namespace framework
}  // namespace paddle
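These stubs only become usable once they are registered with the operator registry. A minimal sketch of how a test translation unit might wire them up, assuming the standard REGISTER_OPERATOR macro from op_registry.h (the op type names are illustrative):

// Registration sketch: dummy ops with real proto makers, no kernels.
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
                  paddle::framework::SumOpMaker,
                  paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
                  paddle::framework::AssignOpMaker,
                  paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(split, paddle::framework::DummyOp,
                  paddle::framework::SplitOpMaker);

With these registrations, a ProgramDesc can create "sum"/"assign"/"split" ops that exercise graph passes without executing real kernels.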
@@ -0,0 +1,94 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

class GraphView {
 public:
  GraphView() = default;

  void Build(ir::Graph* g);

  const std::vector<ir::Node*>& AllOps();

  ir::Node* GetNodeByName(const std::string& name,
                          const std::vector<ir::Node*>& nodes) const;

  std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);

  // Will be deprecated in the future.
  // NOTE(dzhwinter):
  // 1. Python memory optimization reuses memory based on var names, so
  //    different op outputs may share the same variable name. Enabling
  //    inplace on such a node would create a cycle in the SSA graph.
  // 2. DistributeTranspiler uses unique names to map parameters and
  //    gradients, so those must be skipped.
  bool InSkipSet(const std::string& var) const;

 private:
  std::vector<ir::Node*> ops_;
  std::unordered_set<std::string> dup_nodes_;  // nodes affected by memory optimization
  std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
};

// Swap pairs in sequence.
typedef std::vector<std::pair<ir::Node*, ir::Node*>> NodeSwapQueue;
class InplacePass : public ir::Pass {
 public:
  InplacePass();

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

  void InitSSAGraphNodes() const;

 private:
  const NodeSwapQueue TryInplaceModifyVar(const std::string& var,
                                          const std::string& cache_var,
                                          const size_t& idx,
                                          ir::Graph* graph) const;

  void CommitModify(const NodeSwapQueue&, ir::Graph* graph) const;

  void WithdrawModify(const NodeSwapQueue& nodes, ir::Graph* graph) const;

  void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
                         const size_t& idx) const;

  void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const;

  mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;

  mutable std::unordered_set<std::string> whitelist_;
  mutable GraphView view_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
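The Try/Commit/Withdraw trio above implies a transactional update of the graph: tentative node swaps are collected, validated, and then either kept or rolled back. A hedged sketch of the control flow as it might appear inside ApplyImpl; IsValidAfterSwap is a hypothetical validity check, not part of this header:

// Tentatively rename var -> cache_var, collecting (old, new) node pairs.
const NodeSwapQueue swaps = TryInplaceModifyVar(var, cache_var, idx, graph);
if (IsValidAfterSwap(swaps, graph)) {  // hypothetical check, for illustration
  CommitModify(swaps, graph);          // keep the swapped-in nodes
} else {
  WithdrawModify(swaps, graph);        // restore the original nodes
}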
@@ -1,117 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/memory_early_delete_pass.h"
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

// BFS from var_in to the next computation op on the same place.
static ComputationOpHandle* FindNextComputationOpHandle(VarHandle* var_in) {
  std::queue<VarHandleBase*> queue;
  queue.push(var_in);
  do {
    auto* var = queue.front();
    queue.pop();
    for (auto* op : var->PendingOps()) {
      auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
      if (compute_op != nullptr && compute_op->GetPlace() == var_in->place()) {
        return compute_op;
      }
      for (auto* out_var : op->Outputs()) {
        queue.push(out_var);
      }
    }
  } while (!queue.empty());
  return nullptr;
}

std::unique_ptr<ir::Graph> MemoryEarlyDeletePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  auto& graph_pool = Get<GraphNodePool>(kGraphNodePool);
  auto& gcs = Get<GarbageCollectorMap>(kGarbageCollector);

  std::unordered_map<std::string, std::unordered_set<OpDesc*>> unlived_vars;
  unlived_vars.reserve(graph_pool.size());
  for (auto& pair : graph_pool) {
    unlived_vars.insert(std::make_pair(pair.first, pair.second));
  }

  auto compare_and_insert_early_delete_op = [&](
      OpHandleBase* op, const std::vector<VarHandleBase*>& vars) {
    if (unlived_vars.empty()) return;
    // Unlived vars can be deleted after the last op that uses them has
    // finished.
    auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
    const auto& places = Get<std::vector<platform::Place>>(kAllPlaces);
    for (auto& var : vars) {
      auto* var_handle = dynamic_cast<VarHandle*>(var);
      auto var_name = var->Node()->Name();
      if (unlived_vars.count(var_name) == 0) continue;
      if (!unlived_vars[var_name].empty()) {
        if (compute_op != nullptr &&
            unlived_vars[var_name].count(compute_op->Node()->Op()) != 0) {
          unlived_vars[var_name].erase(compute_op->Node()->Op());
        }
        continue;
      }

      // Skip dummy/control-dependency vars; this also guards the place()
      // dereference below against a null handle.
      if (var_handle == nullptr || !var_handle->Node()->IsVar() ||
          var_handle->Node()->IsCtrlVar())
        continue;
      auto& var_place = var_handle->place();

      // Shamelessly copied from the reference-count pass.
      if (compute_op == nullptr) {
        // Use the next computation op's scope.
        compute_op = FindNextComputationOpHandle(var_handle);
      }
      auto* early_delete_node =
          graph->CreateEmptyNode("early_delete", ir::Node::Type::kOperation);
      GarbageCollector* gc = gcs.at(places[compute_op->GetScopeIdx()]).get();
      auto* early_delete_handle = new EarlyDeleteOpHandle(
          early_delete_node, compute_op->GetScope(), var_place, {var_name}, gc);
      if (compute_op->Outputs().empty()) {
        auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
        compute_op->AddOutput(dep_var);
        graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
      }
      early_delete_handle->AddInput(compute_op->Outputs().front());
      VLOG(5) << "Add early delete op " << var_name << " to Operator "
              << compute_op->Name();
    }
  };

  auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
  for (auto& op : all_ops) {
    compare_and_insert_early_delete_op(op, op->Inputs());
    compare_and_insert_early_delete_op(op, op->Outputs());
  }
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(memory_early_delete_pass,
              paddle::framework::details::MemoryEarlyDeletePass)
    .RequireGraphAttr(paddle::framework::details::kGraphNodePool)
    .RequireGraphAttr(paddle::framework::details::kGarbageCollector);
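Once registered, the pass is fetched by name and applied; the two RequireGraphAttr calls mean Apply will fail unless the graph already carries the node pool and garbage-collector attributes. A usage sketch, assuming the ir::PassRegistry interface (attribute wiring elided):

// Sketch only: fetch the registered pass and run it over a graph.
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "memory_early_delete_pass");
// The graph must already hold kGraphNodePool and kGarbageCollector.
graph = pass->Apply(std::move(graph));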
@@ -0,0 +1,182 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace details {

constexpr char kAllOpDescs[] = "all_op_descs";

std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);

// NOTE(dzh): An ordered set for node reuse in memory optimization.
// The ordered set sorts nodes in ascending order by node byte size.
// In Fluid, -1 denotes the batch_size, which is only determined at runtime,
// so reuse may only happen between nodes whose batch_size dimensions are
// either both -1 or both fixed.
//
// Sort rules:
// rule 0: smaller nodes rank in front.
// rule 1: nodes whose batch_size equals -1 rank in front of nodes whose
//         batch_size is fixed.
//
// For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..

class OrderedSet {
 public:
  // Nodes with the same name may exist in the pool.
  using NodeVector = std::vector<ir::Node*>;
  using Iter = typename std::list<NodeVector>::iterator;
  using ConstIter = typename std::list<NodeVector>::const_iterator;

  void Insert(ir::Node* var);
  void Erase(ir::Node* var);
  bool Has(ir::Node* var) const;
  void Clear() {
    mark_table_.clear();
    nodes_.clear();
  }
  // Find the best-fit node in the pool for var by shape.
  ir::Node* FindBestFitNode(ir::Node* var) const;
  // The map stores non-const iterators, so constness cannot be promised.
  int GetNodeIndexInPool(ir::Node* var);
  // Dump all nodes in the pool to a string.
  std::string ToString() const;

  Iter begin() { return nodes_.begin(); }
  Iter end() { return nodes_.end(); }
  ConstIter begin() const { return nodes_.begin(); }
  ConstIter end() const { return nodes_.end(); }

  size_t size() const { return nodes_.size(); }

 private:
  // For searching.
  std::unordered_map<std::string, Iter> mark_table_;
  // Node pool.
  std::list<NodeVector> nodes_;
};

class ControlFlowGraph {
 public:
  ControlFlowGraph() = default;
  // For the IR graph.
  explicit ControlFlowGraph(const ir::Graph& graph);

  void LiveVariableAnalysis();

  void RenameVarInCFGGraph(const std::string& old_node,
                           const std::string& new_node, int begin_idx);

  const std::set<std::string> LiveIn(ir::Node* op) const;
  const std::set<std::string> LiveOut(ir::Node* op) const;
  const std::set<std::string> Use(ir::Node* op) const;
  const std::vector<ir::Node*> Ops() const;
  std::vector<ir::Node*>& Ops();

  // For SSA-graph nodes.
  ir::Node* GetNodeByName(const std::string& name, ir::Node* op) const;

 private:
  void BuildCFGGraph();
  void ConnectNodes();

  using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
  using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
  // Successor ops that use the output variables.
  NodeListMap successors_;
  // Predecessor ops that generated the input variables.
  NodeListMap predecessors_;
  // Variables live before running the current op.
  VarSetMap live_in_;
  // Variables live after running the current op.
  VarSetMap live_out_;
  VarSetMap uses_;  // op inputs
  VarSetMap defs_;  // op outputs

  std::vector<ir::Node*> ops_;  // op sequence by topological sort
};

// Validate whether a tensor can be reused or not.
bool NodeCanReused(ir::Node* node);

// Validate whether a tensor can be reused or not.
bool NodeCanReused(const VarDesc& node);

// Check whether an op has a sub-block or not.
bool OpHasSubBlock(OpDesc* desc);

// Node memory size in bytes.
size_t NodeSize(ir::Node* n);

// Node memory size in bytes.
size_t NodeSize(const VarDesc&);

std::string DebugString(ir::Node* var);

// NOTE(dzhwinter)
// After node reuse, the replaced node's shape differs from its VarDesc,
// so we need to find the correct VarDesc in the Block.
VarDesc* FindVarDescInBlock(ir::Node* n);

static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
  return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
         op1->Outputs() == op2->Outputs();
}

template <typename Container, typename Callback>
class FilterVariableImpl {
 public:
  void operator()(const Container& nodes, Callback callback) {
    for (auto* node : nodes) {
      callback(node);
    }
  }
};

// Filter var nodes for op->inputs/outputs: skip control-dependency vars.
template <typename Callback>
class FilterVariableImpl<std::vector<ir::Node*>, Callback> {
 public:
  void operator()(const std::vector<ir::Node*>& nodes, Callback callback) {
    for (auto* var : nodes) {
      if (var->IsVar() && !var->IsCtrlVar()) {
        callback(var);
      }
    }
  }
};

template <typename Container, typename Callback>
void FilterVariables(const Container& nodes, Callback callback) {
  FilterVariableImpl<Container, Callback>()(nodes, callback);
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
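LiveVariableAnalysis computes the standard backward dataflow fixed point over uses_ and defs_: live_out(op) is the union of live_in over op's successors, and live_in(op) = use(op) ∪ (live_out(op) \ def(op)). A self-contained sketch of that iteration, with toy integer op ids standing in for ir::Node* (illustrative only, not the pass's implementation):

#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using VarSet = std::set<std::string>;

void LivenessFixedPoint(const std::vector<int>& ops,  // topological order
                        const std::map<int, std::vector<int>>& succ,
                        const std::map<int, VarSet>& uses,
                        const std::map<int, VarSet>& defs,
                        std::map<int, VarSet>* live_in,
                        std::map<int, VarSet>* live_out) {
  bool changed = true;
  while (changed) {  // iterate until no live set grows
    changed = false;
    // Walk ops backwards: liveness flows from successors to predecessors.
    for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
      const int op = *it;
      VarSet out;
      if (succ.count(op)) {
        for (int s : succ.at(op)) {
          const VarSet& in = (*live_in)[s];
          out.insert(in.begin(), in.end());
        }
      }
      // live_in = use + (live_out - def)
      VarSet in = uses.count(op) ? uses.at(op) : VarSet{};
      for (const auto& v : out) {
        if (defs.count(op) == 0 || defs.at(op).count(v) == 0) in.insert(v);
      }
      if (out != (*live_out)[op] || in != (*live_in)[op]) {
        (*live_out)[op] = std::move(out);
        (*live_in)[op] = std::move(in);
        changed = true;
      }
    }
  }
}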