add ir memory optimize. (#14530)
* follow comments. test=develop * Fix typo * fix compile error. test=develop * merge develop branch. test=develop * Remove set_equal * Polish code * Delete unused functions test=develop * polish code. test=develop * follow comment * polish code. * fix windows compile error. test=develop * fix op handle. * rerun ci. test=develop * rerun ci. test=develop * rerun macci. test=develop * polish code. test=develop * rewrite sort code. test=develop * remove unused code. test=develop * fix tests. test=develop * fix conflict. test=develop * follow comment. test=develop * merge develop branch. test=develop * fix tests. test=develop * remove ToTypeIndex. test=develop * rerun ci. test=developfor_weibo
parent
fd1d2c897e
commit
7cd24b1318
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,120 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/details/memory_reuse_types.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
// String constant "all_op_descs".
// NOTE(review): presumably an attribute key under which the list of op descs
// is stored -- the users are not visible here, confirm against the call sites.
constexpr char kAllOpDescs[] = "all_op_descs";

// Returns the op nodes of `graph` as an ordered sequence.
// NOTE(review): judging by the name, the order mirrors the original op desc
// order -- confirm against the definition.
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);

// Returns the op nodes of `graph` sorted in BFS order.
std::vector<ir::Node*> BFSSortGraphOps(const ir::Graph& graph);

// Forward declaration; AnalysisVarPass below only holds it through a
// std::unique_ptr.
class ControlFlowGraph;
|
||||
|
||||
class AnalysisVarPass : public ir::Pass {
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const override;
|
||||
|
||||
private:
|
||||
// fill the variable map(var_nodes) by version.
|
||||
void InitSSAGraphNodes() const;
|
||||
// update program descs
|
||||
void RenameVarInGraphDesc(const std::string& var,
|
||||
const std::string& cache_var, size_t idx) const;
|
||||
// update ir nodes
|
||||
void RenameVarInGraphNode(const std::string& var,
|
||||
const std::string& cache_var, size_t idx,
|
||||
ir::Graph* graph) const;
|
||||
|
||||
void SubGraphOptimize(OpDesc* op_desc) const;
|
||||
// valid a tensor can be reuse or not
|
||||
bool NodeCanReused(ir::Node* node) const;
|
||||
// scan subblock and collect the output/input variables.
|
||||
std::unordered_set<std::string> GetSubBlockVars(
|
||||
const std::unordered_set<ir::Node*>&) const;
|
||||
// check op has subblock or not
|
||||
bool OpHasSubBlock(OpDesc* desc) const;
|
||||
|
||||
private:
|
||||
// Reuse Node Pool, Owned.
|
||||
mutable OrderedNodePairPool pool_;
|
||||
// controlflow Graph
|
||||
mutable std::unique_ptr<ControlFlowGraph> cfg_;
|
||||
// skip set
|
||||
mutable std::unordered_set<std::string> skip_set_;
|
||||
// var nodes
|
||||
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
|
||||
};
|
||||
|
||||
// Control flow graph over the op nodes of an ir::Graph. It stores, per op,
// the successor/predecessor ops and the use/def/live-in/live-out variable
// name sets declared below.
class ControlFlowGraph {
 public:
  ControlFlowGraph() = default;
  // For IR Graph in ParallelExecutor.
  explicit ControlFlowGraph(const ir::Graph& graph);

  // Runs live variable analysis.
  // NOTE(review): presumably populates live_in_ / live_out_ -- confirm in the
  // definition.
  void LiveVariableAnalysis();

  // Renames variable `old_node` to `new_node`, starting at op index
  // `begin_idx` (TODO confirm: index into ops_).
  void RenameVarInCFGGraph(const std::string& old_node,
                           const std::string& new_node, int begin_idx);

  // Per-op variable name sets; each accessor returns a copy.
  const std::set<std::string> LiveIn(ir::Node* op) const;
  const std::set<std::string> LiveOut(ir::Node* op) const;
  const std::set<std::string> Use(ir::Node* op) const;
  // Op sequence: by value from the const overload, by mutable reference from
  // the non-const one.
  const std::vector<ir::Node*> Ops() const;
  std::vector<ir::Node*>& Ops();

  // For ssa-graph nodes: looks up the var node called `name` relative to `op`.
  ir::Node* GetNodeFromVarName(const std::string& name, ir::Node* op) const;

 private:
  void BuildCFGGraph();
  void ConnectNodes();
  using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
  using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
  // Successor ops, i.e. ops that use this op's output variables.
  NodeListMap successors_;
  // Predecessor ops, i.e. ops that generate this op's input variables.
  NodeListMap predecessors_;
  // Variables live before the current op runs.
  VarSetMap live_in_;
  // Variables live after the current op has run.
  VarSetMap live_out_;
  VarSetMap uses_;  // op inputs
  VarSetMap defs_;  // op outputs

  std::vector<ir::Node*> ops_;  // op sequence by topology sort
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,140 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
#include "paddle/fluid/framework/details/var_handle.h"
|
||||
#include "paddle/fluid/framework/garbage_collector.h"
|
||||
#include "paddle/fluid/framework/lod_tensor_array.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class EarlyDeleteOpHandle : public OpHandleBase {
|
||||
public:
|
||||
EarlyDeleteOpHandle(ir::Node* node, const Scope* scope,
|
||||
const platform::Place& place,
|
||||
const std::vector<std::string>& names,
|
||||
GarbageCollector* gc)
|
||||
: OpHandleBase(node),
|
||||
scope_(scope),
|
||||
place_(place),
|
||||
names_(names),
|
||||
gc_(gc) {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (IsStreamGarabageCollector()) {
|
||||
auto gpu_place = boost::get<platform::CUDAPlace>(place);
|
||||
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
|
||||
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
~EarlyDeleteOpHandle() {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (IsStreamGarabageCollector()) {
|
||||
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
|
||||
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
|
||||
PADDLE_ENFORCE(cudaEventDestroy(event_));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string Name() const override { return "early_delete"; }
|
||||
|
||||
protected:
|
||||
void RunImpl() override {
|
||||
std::vector<std::shared_ptr<memory::Allocation>> tensors;
|
||||
auto* local_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope*>();
|
||||
for (auto& var_name : names_) {
|
||||
auto* var = local_scope->FindVar(var_name);
|
||||
PADDLE_ENFORCE(var != nullptr,
|
||||
string::Sprintf("Local Scope not has var %s", var_name));
|
||||
if (var->IsType<LoDTensor>()) {
|
||||
tensors.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
|
||||
} else if (var->IsType<SelectedRows>()) {
|
||||
tensors.emplace_back(var->GetMutable<SelectedRows>()
|
||||
->mutable_value()
|
||||
->MoveMemoryHolder());
|
||||
} else if (var->IsType<LoDTensorArray>()) {
|
||||
LoDTensorArray* tensor_array = var->GetMutable<LoDTensorArray>();
|
||||
for (auto& tensor : *tensor_array) {
|
||||
tensors.emplace_back(tensor.MoveMemoryHolder());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!tensors.empty()) {
|
||||
ClearTensors(tensors);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void ClearTensors(
|
||||
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
|
||||
if (platform::is_cpu_place(place_)) {
|
||||
ClearCPUTensors(tensors);
|
||||
} else {
|
||||
ClearGPUTensors(tensors);
|
||||
}
|
||||
}
|
||||
|
||||
void ClearCPUTensors(
|
||||
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
|
||||
auto* gc = dynamic_cast<CPUGarbageCollector*>(gc_);
|
||||
if (gc != nullptr) {
|
||||
gc->Add(tensors);
|
||||
}
|
||||
}
|
||||
|
||||
void ClearGPUTensors(
|
||||
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
auto* gc = dynamic_cast<StreamGarbageCollector*>(gc_);
|
||||
if (gc != nullptr) {
|
||||
auto compute_stream = dev_ctx_->stream();
|
||||
auto callback_stream = gc->stream();
|
||||
auto callback_func = [=]() {
|
||||
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
|
||||
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
|
||||
};
|
||||
gc_->Add(tensors, callback_func);
|
||||
} else {
|
||||
gc_->Add(tensors);
|
||||
}
|
||||
}
|
||||
|
||||
bool IsStreamGarabageCollector() const {
|
||||
return dynamic_cast<const StreamGarbageCollector*>(gc_) != nullptr;
|
||||
#endif
|
||||
}
|
||||
|
||||
const Scope* scope_;
|
||||
const platform::Place place_;
|
||||
std::vector<std::string> names_;
|
||||
GarbageCollector* gc_;
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
platform::CUDADeviceContext* dev_ctx_;
|
||||
cudaEvent_t event_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,117 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/memory_early_delete_pass.h"
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/details/memory_reuse_types.h"
|
||||
#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
||||
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
|
||||
#include "paddle/fluid/framework/ir/graph_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Breadth-first search downstream of `var_in` for the first
// ComputationOpHandle whose place equals `var_in->place_`.
//
// Starting from `var_in`, the pending ops of each var are inspected and the
// outputs of every non-matching op are queued in turn. Returns nullptr when
// no such op is reachable.
static ComputationOpHandle* FindNextComputationOpHandle(VarHandle* var_in) {
  std::queue<VarHandleBase*> queue;
  // Vars already queued. A var reachable along several paths is expanded
  // once only; the first match found is unaffected because re-expanding a
  // var can only revisit ops that were already rejected.
  std::unordered_set<VarHandleBase*> visited;
  queue.push(var_in);
  visited.insert(var_in);
  while (!queue.empty()) {
    auto* var = queue.front();
    queue.pop();
    for (auto* op : var->PendingOps()) {
      auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
      if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
        return compute_op;
      }
      for (auto* out_var : op->Outputs()) {
        if (visited.insert(out_var).second) {
          queue.push(out_var);
        }
      }
    }
  }
  return nullptr;
}
|
||||
|
||||
std::unique_ptr<ir::Graph> MemoryEarlyDeletePass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
auto& graph_pool = Get<GraphNodePool>(kGraphNodePool);
|
||||
auto& gcs = Get<GarbageCollectorMap>(kGarbageCollector);
|
||||
|
||||
std::unordered_map<std::string, std::unordered_set<OpDesc*>> unlived_vars;
|
||||
unlived_vars.reserve(graph_pool.size());
|
||||
for (auto& pair : graph_pool) {
|
||||
unlived_vars.insert(std::make_pair(pair.first, pair.second));
|
||||
}
|
||||
|
||||
auto compare_and_insert_early_delete_op = [&](
|
||||
OpHandleBase* op, const std::vector<VarHandleBase*>& vars) {
|
||||
if (unlived_vars.empty()) return;
|
||||
// unlived vars can be deleted after the last used op has finished.
|
||||
auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
|
||||
const auto& places = Get<std::vector<platform::Place>>(kAllPlaces);
|
||||
for (auto& var : vars) {
|
||||
auto* var_handle = dynamic_cast<VarHandle*>(var);
|
||||
auto var_name = var->Node()->Name();
|
||||
auto& var_place = var_handle->place_;
|
||||
if (unlived_vars.count(var_name) == 0) continue;
|
||||
if (!unlived_vars[var_name].empty()) {
|
||||
if (compute_op != nullptr &&
|
||||
unlived_vars[var_name].count(compute_op->Node()->Op()) != 0) {
|
||||
unlived_vars[var_name].erase(compute_op->Node()->Op());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (var_handle == nullptr || !var_handle->Node()->IsVar() ||
|
||||
var_handle->Node()->IsCtrlVar())
|
||||
continue;
|
||||
|
||||
// shameless copyed from reference count pass.
|
||||
if (compute_op == nullptr) {
|
||||
// use next computation op scope
|
||||
compute_op = FindNextComputationOpHandle(var_handle);
|
||||
}
|
||||
auto* early_delete_node =
|
||||
graph->CreateEmptyNode("early_delete", ir::Node::Type::kOperation);
|
||||
GarbageCollector* gc = gcs.at(places[compute_op->GetScopeIdx()]).get();
|
||||
auto* early_delete_handle = new EarlyDeleteOpHandle(
|
||||
early_delete_node, compute_op->GetScope(), var_place, {var_name}, gc);
|
||||
if (compute_op->Outputs().empty()) {
|
||||
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
|
||||
compute_op->AddOutput(dep_var);
|
||||
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
|
||||
}
|
||||
early_delete_handle->AddInput(compute_op->Outputs().front());
|
||||
VLOG(5) << "Add early delete op " << var_name << " to Operator"
|
||||
<< compute_op->Name();
|
||||
}
|
||||
};
|
||||
|
||||
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
|
||||
for (auto& op : all_ops) {
|
||||
compare_and_insert_early_delete_op(op, op->Inputs());
|
||||
compare_and_insert_early_delete_op(op, op->Outputs());
|
||||
}
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Registers MemoryEarlyDeletePass as "memory_early_delete_pass" and declares
// kGraphNodePool and kGarbageCollector as required graph attributes.
// NOTE(review): ApplyImpl reads kGraphNodePool, kGarbageCollector and
// kAllPlaces through the pass's own Get<>(). Confirm that requiring *graph*
// attributes here, and not listing kAllPlaces, is intended.
REGISTER_PASS(memory_early_delete_pass,
              paddle::framework::details::MemoryEarlyDeletePass)
    .RequireGraphAttr(paddle::framework::details::kGraphNodePool)
    .RequireGraphAttr(paddle::framework::details::kGarbageCollector);
|
@ -0,0 +1,32 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/framework/details/early_delete_op_handle.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Graph pass that adds "early_delete" op handles (EarlyDeleteOpHandle) to the
// graph. ApplyImpl is defined in memory_early_delete_pass.cc.
class MemoryEarlyDeletePass : public ir::Pass {
 protected:
  // Takes ownership of `graph`, inserts the early delete ops and returns it.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,155 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/memory_reuse_types.h"
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Returns the memory size in bytes of the variable behind `n`:
// sizeof(element type) * |product of all shape dims|.
//
// The shape and data type are taken from the VarDesc found by
// FindVarDescInBlock. Taking the absolute value means a single -1 dimension
// contributes a factor of one.
size_t NodeSizeInBytes(ir::Node* n) {
  auto* desc = FindVarDescInBlock(n);
  auto shape = desc->GetShape();
  size_t type_size = SizeOfType(desc->GetDataType());
  // Accumulate in 64 bits: shape dims are 64-bit and the element count of a
  // large tensor does not fit into an int.
  int64_t numel = 1;
  for (auto& dim : shape) {
    numel *= dim;
  }
  if (numel < 0) {
    numel = -numel;
  }
  return type_size * static_cast<size_t>(numel);
}
|
||||
|
||||
// Formats `var` as "<name>[d0,d1,...]". If reading the shape throws, the
// text written so far is followed by a "Var has no VarDesc" message instead
// of the dims and the closing bracket.
std::string DebugStringImpl(VarDesc* var) {
  std::stringstream out;
  out << var->Name() << "[";
  try {
    const auto dims = var->GetShape();
    // Separator is empty before the first dim and "," before every other.
    const char* separator = "";
    for (const auto& dim : dims) {
      out << separator << dim;
      separator = ",";
    }
    out << "]";
  } catch (...) {
    out << "Var has no VarDesc !!! Name:" << var->Name();
  }
  return out.str();
}
|
||||
|
||||
// Debug text for a var node: resolves the node's VarDesc in its block and
// formats it with DebugStringImpl.
std::string DebugString(ir::Node* var) {
  VarDesc* desc = FindVarDescInBlock(var);
  return DebugStringImpl(desc);
}
|
||||
// return DebugString(var->Var()); }
|
||||
|
||||
// NOTE(dzh): based ir node, if a large node has been reused
|
||||
// by a small size node, then next time it appear in pool, it will
|
||||
// have the small size. Find the original node shap from blockdesc.
|
||||
VarDesc* FindVarDescInBlock(ir::Node* n) {
|
||||
PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
|
||||
BlockDesc* block = n->inputs[0]->Op()->Block();
|
||||
PADDLE_ENFORCE(block->HasVar(n->Name()),
|
||||
string::Sprintf("Block do not has var %s", n->Name()));
|
||||
return block->FindVar(n->Name());
|
||||
}
|
||||
|
||||
struct NodeComparator {
|
||||
bool operator()(ir::Node* lhs, ir::Node* rhs) const {
|
||||
auto* lhs_desc = FindVarDescInBlock(lhs);
|
||||
auto* rhs_desc = FindVarDescInBlock(rhs);
|
||||
auto lhs_shape = lhs_desc->GetShape();
|
||||
auto rhs_shape = rhs_desc->GetShape();
|
||||
if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) ||
|
||||
(lhs_shape[0] != -1 && rhs_shape[0] != -1)) {
|
||||
return NodeSizeInBytes(lhs) <= NodeSizeInBytes(rhs);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Adds `op` as a user of `var`.
//
// If a node with the same name is already pooled, `op` is added to its op
// set. Otherwise a new entry is inserted so that the list stays ordered:
// nodes with a runtime batch size (first dim == -1) come first, and within
// each group entries are ascending by byte size, with a new entry placed
// after existing entries of equal size.
void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) {
  PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar());
  PADDLE_ENFORCE(op->IsOp());
  auto known = mark_table_.find(var->Name());
  if (known != mark_table_.end()) {
    known->second->second.insert(op);
    return;
  }

  // A var with an empty shape has no batch dim and counts as static; the
  // emptiness check also avoids indexing past the end of the shape. The dim
  // is compared at full width rather than truncated to int.
  auto has_dynamic_batch = [](ir::Node* node) {
    auto shape = FindVarDescInBlock(node)->GetShape();
    return !shape.empty() && shape[0] == -1;
  };
  const bool var_dynamic = has_dynamic_batch(var);

  Iter it = nodes_.begin();
  for (; it != nodes_.end(); ++it) {
    const bool cache_dynamic = has_dynamic_batch(it->first);
    if (cache_dynamic == var_dynamic) {
      // Same group: stop at the first strictly larger entry.
      if (!(NodeSizeInBytes(it->first) <= NodeSizeInBytes(var))) {
        break;
      }
    } else if (!cache_dynamic) {
      // Reached the static group while inserting a dynamic var.
      break;
    }
    // Otherwise a dynamic entry precedes a static var: keep scanning.
  }

  it =
      nodes_.insert(it, std::make_pair(var, std::unordered_set<ir::Node*>{op}));
  mark_table_[var->Name()] = it;
}
|
||||
|
||||
// Returns the zero-based position of `var`'s entry in the ordered list.
//
// `var` must be non-null and present in the pool; both are enforced. The
// lookup uses find() so that an unknown name no longer inserts a singular
// iterator into mark_table_, which std::distance must not be given.
int OrderedNodePairPool::GetIndex(ir::Node* var) {
  PADDLE_ENFORCE(var != nullptr);
  auto entry = mark_table_.find(var->Name());
  PADDLE_ENFORCE(entry != mark_table_.end(),
                 string::Sprintf("Pool do not has var %s", var->Name()));
  return static_cast<int>(std::distance(nodes_.begin(), entry->second));
}
|
||||
|
||||
// Returns the first pooled node, in list order, that `var` fits into
// according to NodeComparator, or nullptr if there is none.
ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const {
  NodeComparator fits_into;
  for (const auto& entry : nodes_) {
    ir::Node* candidate = entry.first;
    if (fits_into(var, candidate)) {
      return candidate;
    }
  }
  return nullptr;
}
|
||||
|
||||
// Removes the entry named like `var` from both the ordered list and the
// lookup table. The name must be present in the pool.
void OrderedNodePairPool::Erase(ir::Node* var) {
  const auto name = var->Name();
  PADDLE_ENFORCE(mark_table_.count(name));
  auto entry = mark_table_.find(name);
  nodes_.erase(entry->second);
  mark_table_.erase(entry);
}
|
||||
|
||||
std::string OrderedNodePairPool::ToString() const {
|
||||
std::stringstream ss;
|
||||
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
|
||||
ss << DebugString(it->first) << " ";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,87 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <list>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// String constant "fetched_vars".
// NOTE(review): presumably an attribute key -- its users are not visible
// here, confirm against the call sites.
constexpr char kFetchedVars[] = "fetched_vars";
// Attribute key under which the early delete pass reads its GraphNodePool.
constexpr char kGraphNodePool[] = "graph_node_pool";

// NOTE(dzh): Each entry pairs a variable name with the operators that use
// that variable; it is consumed by the early delete pass.
// The analysis var pass is built on ir::Node, which may be released or
// modified between passes, so ops are identified by OpDesc* instead.
using GraphNodePool = std::vector<
    std::pair<std::string /*var node*/, std::unordered_set<OpDesc*> /* ops */>>;
|
||||
|
||||
// NOTE(dzh): by default, it sort node in ascend order(by node bytes size).
|
||||
// in fluid, -1 means the batch_size is determined in runtime.
|
||||
// the node batch_size equal -1 always ranking in the front than the node not.
|
||||
// For example,
|
||||
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
|
||||
// O(1) insert, delete
|
||||
class OrderedNodePairPool {
|
||||
public:
|
||||
using NodePair = std::pair<ir::Node*, std::unordered_set<ir::Node*>>;
|
||||
using Iter = typename std::list<NodePair>::iterator;
|
||||
using ConstIter = typename std::list<NodePair>::const_iterator;
|
||||
|
||||
void Insert(ir::Node* var, ir::Node* op);
|
||||
|
||||
void Erase(ir::Node* var);
|
||||
|
||||
bool Has(ir::Node* var) { return mark_table_.count(var->Name()); }
|
||||
|
||||
ir::Node* NodeMatch(ir::Node* var) const;
|
||||
// map store non-const iterator, can not promise const
|
||||
int GetIndex(ir::Node* var);
|
||||
// pool all node to string
|
||||
std::string ToString() const;
|
||||
|
||||
Iter begin() { return nodes_.begin(); }
|
||||
Iter end() { return nodes_.end(); }
|
||||
ConstIter begin() const { return nodes_.begin(); }
|
||||
ConstIter end() const { return nodes_.end(); }
|
||||
size_t size() const { return nodes_.size(); }
|
||||
|
||||
private:
|
||||
// for searching.
|
||||
std::unordered_map<std::string, Iter> mark_table_;
|
||||
// node swap pairs. var -> ops dep var
|
||||
std::list<NodePair> nodes_;
|
||||
};
|
||||
|
||||
// Memory size of the var node `n` in bytes: element size of its data type
// times the absolute value of the product of its shape dims, both read from
// the VarDesc returned by FindVarDescInBlock.
size_t NodeSizeInBytes(ir::Node* n);

// Debug text for a var node: its name followed by its shape in brackets.
std::string DebugString(ir::Node* var);

// std::string DebugString(VarDesc* var);
// Returns the VarDesc named like `n` from the block of the op behind `n`'s
// single input. `n` must be a non-control var node with exactly one input.
VarDesc* FindVarDescInBlock(ir::Node* n);
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/memory_reuse_types.h"
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "glog/logging.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Exercises Insert / Erase / GetIndex / NodeMatch of OrderedNodePairPool on
// eight var nodes that all hang off a single dummy op.
TEST(OrderedNodePairPool, Normal) {
  OrderedNodePairPool pool;
  std::vector<std::unique_ptr<ir::Node>> nodes;

  // clang-format off
  std::vector<std::vector<int64_t>> shapes = {{-1, 10},
                                              {-1, 20},
                                              {1, 2},
                                              {5, 2},
                                              {10, 20},
                                              {-1, 2, 5},
                                              {-1, 1, 5},
                                              {-1, 1}};
  // clang-format on
  const int COUNT = shapes.size();
  ProgramDesc prog;
  BlockDesc* block_desc = prog.MutableBlock(0);
  auto* op_desc = block_desc->AppendOp();
  op_desc->SetType("dummy");
  std::unique_ptr<ir::Node> op = ir::CreateNodeForTest(op_desc);

  // Creates a var named `name` with `shape` in the block and wraps it in a
  // node whose only input is the dummy op.
  auto make_var_node = [&](const std::string& name,
                           const std::vector<int64_t>& shape) {
    auto desc = block_desc->Var(name);
    desc->SetShape(shape);
    std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
    node->inputs.emplace_back(op.get());
    return node;
  };

  for (int i = 0; i < COUNT; ++i) {
    nodes.emplace_back(make_var_node(std::to_string(i), shapes[i]));
  }

  for (auto& node : nodes) {
    pool.Insert(node.get(), op.get());
  }

  // assert its order and interface.
  std::cout << pool.ToString() << std::endl;
  pool.Erase(nodes.front().get());
  std::cout << pool.ToString() << std::endl;

  ASSERT_EQ(pool.size(), static_cast<size_t>(COUNT - 1));
  // Var "7" ([-1,1]) is the smallest runtime-batch node, so it ranks first.
  ASSERT_EQ(pool.GetIndex(nodes.back().get()), 0);

  {
    // Larger than every pooled runtime-batch node: nothing to reuse.
    auto probe = make_var_node("11", {-1, 256, 56, 56});
    auto* cache = pool.NodeMatch(probe.get());
    ASSERT_EQ(cache, nullptr);
  }
  {
    auto probe = make_var_node("12", {-1, 2, 5});
    auto* cache = pool.NodeMatch(probe.get());
    ASSERT_EQ(pool.GetIndex(cache), 2);  // matches var "5": [-1,2,5]
  }
  {
    auto probe = make_var_node("13", {2, 5});
    auto* cache = pool.NodeMatch(probe.get());
    ASSERT_EQ(pool.GetIndex(cache), 5);  // matches var "3": [5,2]
  }
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue