Remove legacy C++ memory optimization code (#18834)
* remove legacy memory optimization code, test=develop
* follow huihuang's comments, test=develop
* follow luotao's comments, test=develop
parent 52c1431eee
commit 8008ab4e6b
@@ -1,187 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

/// this attribute is used to avoid some core variables removed/reused
/// in memory optimize related passes
constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@";
typedef std::unordered_set<std::string> MemOptSkipVars;

std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);

// NOTE(dzh): A ordered set for node reuse in memory optimize.
// the orderedset sort node in ascend order(by node bytes size).
// in fluid, -1 means the batch_size, which is determined in runtime.
// So the reuse happens between nodes who's batch_size both are -1
// simultaneously or not.
//
// sort rule:
// rule 0 : smaller node ranking in front.
// rule 1 : batch_size equal -1 ranking in the front than the node not.
//
// For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..

class OrderedSet {
 public:
  // nodes with same name exists in pool.
  using NodeVector = std::vector<ir::Node*>;
  using Iter = typename std::list<NodeVector>::iterator;
  using ConstIter = typename std::list<NodeVector>::const_iterator;

  void Insert(ir::Node* var);
  void Erase(ir::Node* var);
  void Erase(const std::string& var);
  bool Has(ir::Node* var) const;
  void Clear() {
    mark_table_.clear();
    nodes_.clear();
  }
  // find the bestfit shape node block with var.
  ir::Node* FindBestFitNode(ir::Node* var) const;
  ir::Node* FindNextBestFitNode(ir::Node* var, ir::Node* prev) const;
  // map store non-const iterator, can not promise const
  int GetNodeIndexInPool(ir::Node* var);
  // pool all node to string
  std::string ToString() const;

  Iter begin() { return nodes_.begin(); }
  Iter end() { return nodes_.end(); }
  ConstIter begin() const { return nodes_.begin(); }
  ConstIter end() const { return nodes_.end(); }

  size_t size() const { return nodes_.size(); }

 private:
  // for searching.
  std::unordered_map<std::string, Iter> mark_table_;
  // node pool
  std::list<NodeVector> nodes_;
};

class ControlFlowGraph {
 public:
  ControlFlowGraph() = default;
  // IR Graph
  explicit ControlFlowGraph(const ir::Graph& graph);

  void LiveVariableAnalysis();

  void RenameVarInCFGGraph(const std::string& old_node,
                           const std::string& new_node, int begin_idx);

  const std::set<std::string>& LiveIn(ir::Node* op) const;
  const std::set<std::string>& LiveOut(ir::Node* op) const;
  const std::set<std::string>& Use(ir::Node* op) const;
  const std::set<std::string>& Unlived(ir::Node* op) const;
  const std::vector<ir::Node*>& Ops() const;
  std::vector<ir::Node*>& Ops();

  // for ssa-graph nodes
  ir::Node* GetNodeByName(const std::string& name, ir::Node* op) const;

 private:
  void BuildCFGGraph();
  void ConnectNodes();

  using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
  using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
  // successors ops use the output variables.
  NodeListMap successors_;
  // predecessors ops generated input variables.
  NodeListMap predecessors_;
  // variables lived before run current op.
  VarSetMap live_in_;
  // variables lived after run current op.
  VarSetMap live_out_;
  VarSetMap uses_;  // op inputs
  VarSetMap defs_;  // op outputs
  std::unordered_map<ir::Node*, std::set<std::string>> unlived_vars_;

  std::vector<ir::Node*> ops_;  // op sequence by topology sort
};

// valid a tensor can be reuse or not
bool NodeCanReused(ir::Node* node);

// valid a tensor can be reuse or not.
bool NodeCanReused(const VarDesc& node);

// check op has subblock or not
bool OpHasSubBlock(OpDesc* desc);

// node memory size in bytes
size_t NodeSize(ir::Node* n);

// node memory size in bytes
size_t NodeSize(const VarDesc&);

std::string DebugString(ir::Node* var);

VarDesc* GetVarDesc(ir::Node* n);

static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
  return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
         op1->Outputs() == op2->Outputs();
}

template <typename Container, typename Callback>
class FilterVariableImpl {
 public:
  void operator()(const Container& nodes, Callback callback) {
    for (auto* node : nodes) {
      callback(node);
    }
  }
};

// filter var node for op->inputs/outputs
template <typename Callback>
class FilterVariableImpl<std::vector<ir::Node*>, Callback> {
 public:
  void operator()(const std::vector<ir::Node*>& nodes, Callback callback) {
    for (auto* var : nodes) {
      if (var->IsVar() && !var->IsCtrlVar()) {
        callback(var);
      }
    }
  }
};

template <typename Container, typename Callback>
void FilterVariables(const Container& nodes, Callback callback) {
  FilterVariableImpl<Container, Callback>()(nodes, callback);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
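The OrderedSet removed above keeps reuse candidates sorted by byte size so that FindBestFitNode can return the smallest pooled buffer that still fits a new variable. The following standalone sketch (not part of the removed files, and not the Paddle API; it ignores the batch_size == -1 rule and the grouping of same-named nodes for brevity) illustrates that best-fit idea:

// Minimal sketch of a size-ordered reuse pool, illustrating the idea behind
// OrderedSet::FindBestFitNode. All names here are illustrative, not Paddle API.
#include <cassert>
#include <cstddef>
#include <list>
#include <string>

struct VarInfo {       // stand-in for an ir::Node that carries a VarDesc
  std::string name;
  std::size_t bytes;   // approximate tensor size in bytes
};

class SizeOrderedPool {
 public:
  // Keep candidates sorted by ascending size ("smaller node ranking in front").
  void Insert(const VarInfo& var) {
    auto it = pool_.begin();
    while (it != pool_.end() && it->bytes < var.bytes) ++it;
    pool_.insert(it, var);
  }

  // Best fit: the first (i.e. smallest) pooled buffer that is large enough.
  const VarInfo* FindBestFit(std::size_t wanted_bytes) const {
    for (const VarInfo& v : pool_) {
      if (v.bytes >= wanted_bytes) return &v;
    }
    return nullptr;  // nothing fits; the caller allocates a fresh buffer
  }

 private:
  std::list<VarInfo> pool_;
};

int main() {
  SizeOrderedPool pool;
  pool.Insert({"b", 4096});
  pool.Insert({"a", 1024});  // stays sorted: a(1024), b(4096)
  const VarInfo* hit = pool.FindBestFit(2000);
  assert(hit != nullptr && hit->name == "b");  // smallest buffer that fits
  return 0;
}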
@@ -1,224 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include <algorithm>
#include <atomic>
#include <deque>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_set>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace ir {

void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
  CollectSkipVarsSet(graph);

  cfg_.reset(new ControlFlowGraph(*graph));
  cfg_->LiveVariableAnalysis();
  InitSSAGraphNodes();

  int reuse_id = 0;
  for (size_t idx = 0; idx < cfg_->Ops().size(); ++idx) {
    auto& op = cfg_->Ops()[idx];
    auto* op_desc = op->Op();
    // some op in graph has no op desc
    if (op_desc == nullptr) continue;

    for (auto& var : op->outputs) {
      if (var->IsVar() && !var->IsCtrlVar() && skip_set_.count(var->Name())) {
        VLOG(3) << "Skip set contains variable of " << var->Name()
                << "disable reuse on it. skipped";
        continue;
      }
      if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) {
        ir::Node* cache = pool_.FindBestFitNode(var);
        while (cache != nullptr && var->Name() == cache->Name()) {
          VLOG(3) << "The same cache variable is cascade reused. "
                  << cache->Name() << " is re-filled to the pool after "
                  << "the reused op is finished. Current op can not "
                  << "replace it again. Skip this candidate.";
          cache = pool_.FindNextBestFitNode(var, cache);
        }

        if (cache != nullptr) {
          int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
          VLOG(3) << string::Sprintf(
              "!!! %s, %s => %s, cache idx %d, pool size %d",
              std::to_string(reuse_id++), DebugString(var), DebugString(cache),
              node_idx_in_pool, static_cast<int>(pool_.size()));
          // NOTE(dzhwinter): update the ProgramDesc/IR Graph
          // and the CFG Graph on the fly.
          //
          // IR Graph define the dependence relationship between nodes.
          //
          // ProgramDesc defines the input/output vars. Its used in
          // CreateOp, CreateVar when running happens.
          //
          // CFG Graph store the liveness information, when reuse happens
          // we also need to update the variable liveness.
          const std::string var_name = var->Name();
          const std::string cache_name = cache->Name();

          cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
          RenameVarInGraphDesc(var_name, cache_name, idx);
          RenameVarInGraphNode(var_name, cache_name, idx, graph);
          pool_.Erase(cache_name);
        }
      }
    }
    // fill the pool
    for (auto& var : cfg_->Unlived(op)) {
      ir::Node* var_node = cfg_->GetNodeByName(var, op);
      if (var_node == nullptr || var_node->IsCtrlVar()) continue;
      if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
        pool_.Insert(var_node);
      }
    }
  }
  graph->ResolveHazard(var_nodes_);
}

void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
  // fill skip_set_
  PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));
  auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
  for (const auto& var : mem_opt_whitelist) {
    skip_set_.emplace(var);
  }
}

void MemoryOptimizePass::RenameVarInGraphDesc(const std::string& var,
                                              const std::string& cache_var,
                                              size_t idx) const {
  for (size_t i = idx; i < cfg_->Ops().size(); ++i) {
    auto* op = cfg_->Ops()[i];
    PADDLE_ENFORCE(op->IsOp() && op->Op());
    auto* op_desc = op->Op();
    op_desc->RenameInput(var, cache_var);
    op_desc->RenameOutput(var, cache_var);
    if (op_desc->Block() != nullptr) {
      op_desc->Block()->RemoveVar(var);
    } else {
      LOG(WARNING) << "op " << op->Name() << " not know its block."
                   << "Is the op_desc created without block pointer? "
                   << "Can not find " << var << " in Block(0)";
    }
    op_desc->Flush();
  }
}

void MemoryOptimizePass::InitSSAGraphNodes() const {
  std::unordered_map<std::string, std::unordered_set<ir::Node*>> all_vars;
  if (var_nodes_.empty()) {
    for (auto* op : cfg_->Ops()) {
      for (auto* node : op->inputs) {
        if (all_vars[node->Name()].count(node) == 0) {
          all_vars[node->Name()].emplace(node);
          var_nodes_[node->Name()].emplace_back(node);
        }
      }
      for (auto* node : op->outputs) {
        if (all_vars[node->Name()].count(node) == 0) {
          all_vars[node->Name()].emplace(node);
          var_nodes_[node->Name()].emplace_back(node);
        }
      }
    }
  }
}

void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
                                              const std::string& cache_var,
                                              size_t idx,
                                              ir::Graph* graph) const {
  // if replace happens, we need to create a newer version cache_var
  // but use the same dims/data_type with var.
  PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
                 var_nodes_[var].at(0)->Var() != nullptr);
  std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
  var_desc->SetName(cache_var);

  for (size_t i = idx; i < cfg_->Ops().size(); ++i) {
    auto* op = cfg_->Ops()[i];

    // redirect the input to the latest version of cache_var
    for (auto* node : op->inputs) {
      if (node->Name() == var) {
        ir::Node* cache_node = var_nodes_[cache_var].back();

        // swap node to cache_node
        cache_node->outputs.insert(cache_node->outputs.end(),
                                   node->outputs.begin(), node->outputs.end());
        PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
        auto* prev_op = node->inputs[0];
        std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
                     cache_node);
        for (auto* next_op : node->outputs) {
          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                       cache_node);
        }

        // erase unused node
        auto& nodes = var_nodes_.at(var);
        nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
        graph->RemoveNode(node);
      }
    }

    // if we need to rename the output,
    // always create a newer version of cache_var
    for (auto* node : op->outputs) {
      if (node->Name() == var) {
        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
        var_nodes_[cache_var].emplace_back(cache_node);

        // swap node to cache node
        cache_node->outputs.insert(cache_node->outputs.end(),
                                   node->outputs.begin(), node->outputs.end());
        cache_node->inputs.emplace_back(op);
        std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
        for (auto* next_op : node->outputs) {
          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                       cache_node);
        }

        // erase unused node
        auto& nodes = var_nodes_.at(var);
        nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
        graph->RemoveNode(node);
      }
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(memory_optimize_pass, paddle::framework::ir::MemoryOptimizePass)
    .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
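In the removed ApplyImpl above, the pass walks ops in topological order, tries to serve every reusable output from a pool of variables whose lifetime has ended, and then refills the pool from Unlived(op). The condensed, self-contained sketch below (not part of the removed files; Op, live_out, and the pool are illustrative stand-ins, and the real renaming calls are only referenced in comments) shows that liveness-driven reuse loop:

// Simplified sketch of the reuse loop in MemoryOptimizePass::ApplyImpl.
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

void ReuseDeadVars(const std::vector<Op>& ops,
                   const std::vector<std::unordered_set<std::string>>& live_out,
                   const std::unordered_set<std::string>& skip_vars) {
  std::unordered_set<std::string> pool;  // names of buffers free for reuse
  for (std::size_t i = 0; i < ops.size(); ++i) {
    // Try to serve each output from the pool (the real pass asks the
    // OrderedSet for a best-fit node instead of taking an arbitrary one).
    for (const std::string& out : ops[i].outputs) {
      if (skip_vars.count(out) || pool.empty()) continue;
      std::string cache = *pool.begin();
      // Here the real pass rewrites `out` to `cache` from op i onwards via
      // RenameVarInCFGGraph / RenameVarInGraphDesc / RenameVarInGraphNode.
      pool.erase(cache);
    }
    // Inputs that are no longer live after op i become reuse candidates,
    // mirroring the "fill the pool" loop over cfg_->Unlived(op).
    for (const std::string& in : ops[i].inputs) {
      if (!live_out[i].count(in) && !skip_vars.count(in)) pool.insert(in);
    }
  }
}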
@@ -1,72 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class MemoryOptimizePass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;
  // fill the variable map(var_nodes) by version.
  void InitSSAGraphNodes() const;

 private:
  // update program descs
  void RenameVarInGraphDesc(const std::string& var,
                            const std::string& cache_var, size_t idx) const;
  // update ir nodes
  void RenameVarInGraphNode(const std::string& var,
                            const std::string& cache_var, size_t idx,
                            ir::Graph* graph) const;

  void SubGraphOptimize(OpDesc* op_desc) const;
  // 1. scan op with subblock and collect the output/input vars.
  // while, while_grad, conditional_block
  // 2. scan distributed ops and collect the output/input vars
  // 3. op_role_vars
  void CollectSkipVarsSet(ir::Graph* graph) const;

 private:
  // Reuse Node Pool, Owned.
  mutable OrderedSet pool_;
  // controlflow Graph
  mutable std::unique_ptr<ControlFlowGraph> cfg_;
  // skip set
  mutable std::unordered_set<std::string> skip_set_;
  // var nodes
  mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -1,170 +0,0 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
namespace ir {

class RecordSkipMemoryOptVarsPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph* graph) const override {
    PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
    graph->Set(kMemOptSkipVars, new MemOptSkipVars);
    auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);

    std::vector<ir::Node*> op_nodes;
    for (auto& node : graph->Nodes()) {
      PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
      if (node->IsOp() && node->Op()) {
        op_nodes.emplace_back(node);
      }
    }

    // Insert kEmptyVarName to avoid optimizing empty variable
    skip_vars.insert(framework::kEmptyVarName);

    // NOTE(zcd): Insert OpRoleVars to SkipVarSet to prevent the vars are rename
    // in memory optimize pass.
    InsertOpRoleVarsToSkipVarSet(op_nodes, &skip_vars);

    InsertSkipMemOptOpInOutToSkipVarSet(op_nodes, &skip_vars);
  }

 private:
  static void InsertOpRoleVarsToSkipVarSet(const std::vector<ir::Node*>& ops,
                                           MemOptSkipVars* skip_vars) {
    for (auto& node : ops) {
      try {
        auto op_role_vars =
            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
        PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
        for (size_t i = 0; i < op_role_vars.size(); i += 2) {
          auto& g_name = op_role_vars[i + 1];
          skip_vars->insert(g_name);
        }
      } catch (boost::bad_get& e) {
      }
    }
  }

  static void UpdateSkipVarSet(
      MemOptSkipVars* skip_vars,
      const std::vector<std::vector<std::string>>& var_names) {
    for (auto& var_name : var_names) {
      skip_vars->insert(var_name.begin(), var_name.end());
    }
  }

  static std::vector<std::string> ToGradVarName(
      const std::vector<std::string>& names) {
    std::vector<std::string> ret;
    ret.reserve(names.size());
    for (auto& name : names) {
      if (name != framework::kEmptyVarName) {
        ret.emplace_back(framework::GradVarName(name));
      }
    }
    return ret;
  }

  static void InsertSkipMemOptOpInOutToSkipVarSet(
      const std::vector<ir::Node*>& ops, MemOptSkipVars* skip_vars) {
    static std::unordered_set<std::string> kSkipMemOptOps{
        "send", "recv", "prefetch", "send_barrier", "fetch_barrier"};

    for (auto& node : ops) {
      auto* op_desc = node->Op();
      // Some ops (while, conditional_block, recurrent, etc.) have sub-blocks.
      // These ops often use variables from its parent or forward blocks.
      // Optimizing in/out of such ops would make these variables cannot
      // be found when running sub-block ops.
      if (OpHasSubBlock(op_desc)) {
        UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
                                     op_desc->OutputArgumentNames()});
      }

      // Skip ops that are related to parameter server.
      // In distributed mode, trainers and parameter server use same
      // variable names to track same variables. We cannot change the
      // names of these variables, otherwise trainers or parameter
      // server would not find them.
      if (kSkipMemOptOps.count(op_desc->Type()) > 0) {
        UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
                                     op_desc->OutputArgumentNames()});
      }

      // FIXME(zjl): some ops use variables that are not from their
      // inputs or outputs. We do not have a nice method to solve this
      // issue yet. Currently, we should skip these variables when
      // memory optimization is enabled.
      auto op_type = op_desc->Type();
      if (op_type == "while_grad") {
        // In while_grad, framework::GradVarName(Input("X")) is visited
        // without being any in/out of while_grad. While_grad uses
        // these variable to accumulate gradient of X across time steps.
        UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("X"))});
      } else if (op_type == "conditional_block_grad") {
        // In conditional_block_grad, framework::GradVarName(Input("Input",
        // "Cond")) is visited without being any in/out of
        // conditional_block_grad. Conditional_block_grad uses these
        // variables to accumulate gradient of Input/Cond across time steps.
        UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("Input")),
                                     ToGradVarName(op_desc->Input("Cond"))});
      } else if (op_type == "recurrent" || op_type == "recurrent_grad") {
        // Recurrent and recurrent_grad ops are implemented by a very trickly
        // way. Attr("states", "ex_states") is visited without being any
        // in/out of op. It is because these variables are from sub blocks,
        // not main block. Adding these variables to input would make recurrent
        // fail since "states" and "ex_states" cannot be found in main block.
        // When memory optimization is enabled, "states", "ex_states" and their
        // gradient should be skipped.
        auto ex_states =
            boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
        auto states =
            boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
        if (op_type == "recurrent") {
          UpdateSkipVarSet(skip_vars, {ex_states, states});
        } else {
          // In recurrent_grad, framework::GradVarName(Input("parameters",
          // "input")) is visited without being any in/out of recurrent_grad.
          // Recurrent_grad uses these variables to accumulate gradient of
          // parameters/input across time steps.
          UpdateSkipVarSet(
              skip_vars,
              {ToGradVarName(op_desc->Input("parameters")),
               ToGradVarName(op_desc->Input("inputs")), ex_states, states,
               ToGradVarName(ex_states), ToGradVarName(states)});
        }
      }
    }
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(record_skip_memory_opt_vars_pass,
              paddle::framework::ir::RecordSkipMemoryOptVarsPass);
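The removed pass above boils down to two broad rules plus a few op-specific special cases: any variable read or written by an op that owns a sub-block, or by a parameter-server communication op, must keep its original name. The short standalone sketch below (not part of the removed files; OpInfo and CollectSkipVars are illustrative stand-ins for OpDesc and the real collection logic, and the op-role and gradient-accumulator cases are omitted) condenses those two rules:

// Compact sketch of the skip-variable collection rules.
#include <string>
#include <unordered_set>
#include <vector>

struct OpInfo {
  std::string type;
  bool has_sub_block;                  // e.g. while, conditional_block, recurrent
  std::vector<std::string> in_args;    // input argument variable names
  std::vector<std::string> out_args;   // output argument variable names
};

std::unordered_set<std::string> CollectSkipVars(const std::vector<OpInfo>& ops) {
  // Parameter-server ops whose variable names must stay stable across
  // trainers and servers.
  static const std::unordered_set<std::string> kPServerOps{
      "send", "recv", "prefetch", "send_barrier", "fetch_barrier"};
  std::unordered_set<std::string> skip;
  for (const OpInfo& op : ops) {
    if (op.has_sub_block || kPServerOps.count(op.type)) {
      skip.insert(op.in_args.begin(), op.in_args.end());
      skip.insert(op.out_args.begin(), op.out_args.end());
    }
  }
  return skip;
}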