parent 266e625d2e
commit 8f3b252392
@@ -0,0 +1,74 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <map>
#include <memory>  // for std::unique_ptr, used by ApplyImpl below
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

class GraphView {
 public:
  GraphView() = default;

  void Build(ir::Graph* g);

  const std::vector<ir::Node*> AllOps();

  ir::Node* GetNodeByName(const std::string& name,
                          const std::vector<ir::Node*>& nodes) const;

  std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);

  bool OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var);

 private:
  std::vector<ir::Node*> ops_;
};

class InplacePass : public ir::Pass {
 public:
  InplacePass();

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

  void InitSSAGraphNodes() const;

 private:
  void InplaceModifyVar(const std::string& in_var, const std::string& out_var,
                        const size_t& idx, ir::Graph* graph) const;

  void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
                         const size_t& idx) const;

  void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const;

  mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;

  mutable std::unordered_set<std::string> whitelist_;
  mutable GraphView view_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
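The header above only declares the interface, so here is a minimal, self-contained sketch (a toy node type, not Paddle's ir::Node; all names illustrative) of the kind of query GraphView::PendingOpsOnVar answers: which operators still read a variable, and therefore block reuse of its memory.

#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  bool is_op;
  std::vector<ToyNode*> inputs;   // for an op: the var nodes it reads
  std::vector<ToyNode*> outputs;  // for a var: the nodes that consume it
};

// Toy analogue of GraphView::PendingOpsOnVar: the ops that still read `var`.
std::vector<ToyNode*> PendingOpsOnVar(ToyNode* var) {
  std::vector<ToyNode*> ops;
  for (ToyNode* n : var->outputs) {
    if (n->is_op) ops.push_back(n);
  }
  return ops;
}

int main() {
  ToyNode x{"x", false, {}, {}};
  ToyNode relu{"relu", true, {&x}, {}};
  ToyNode scale{"scale", true, {&x}, {}};
  x.outputs = {&relu, &scale};
  // Prints relu and scale: both ops are still pending on x, so an in-place
  // pass could not hand x's buffer to either op's output yet.
  for (ToyNode* op : PendingOpsOnVar(&x)) std::cout << op->name << "\n";
}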
@@ -0,0 +1,135 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstdlib>   // for std::abs
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"

namespace paddle {
namespace framework {

/*
  In-place inference creates In->Out pairs for an operator that can run in
  place. If we specify a pair of corresponding names, for example X->Out,
  then Out will reuse X's memory. The base class performs legality
  validation for both variables.
*/
class InplaceOpInference {
 public:
  virtual ~InplaceOpInference() {}
  virtual std::unordered_map<std::string, std::string> operator()(
      const OpDesc& op_desc, BlockDesc* block) const = 0;
};

class InplaceInToOut : public InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const OpDesc& op_desc, BlockDesc* block) const override {
    std::unordered_map<std::string, std::string> ret;
    auto in_out_var_names_pair = this->Apply(op_desc, block);
    for (auto& pair : in_out_var_names_pair) {
      PADDLE_ENFORCE(!op_desc.Input(pair.first).empty(),
                     string::Sprintf("op %s does not have input %s!",
                                     op_desc.Type(), pair.first));
      PADDLE_ENFORCE(!op_desc.Output(pair.second).empty(),
                     string::Sprintf("op %s does not have output %s!",
                                     op_desc.Type(), pair.second));
      auto& in_name = op_desc.Input(pair.first).at(0);
      auto& out_name = op_desc.Output(pair.second).at(0);

      auto in = block->FindRecursiveOrCreateVar(in_name);
      auto out = block->FindRecursiveOrCreateVar(out_name);
      if (TryInplaceInputOutput(in, out)) ret.insert({in_name, out_name});
    }
    return ret;
  }

 protected:
  virtual std::unordered_map<std::string, std::string> Apply(
      const OpDesc& op_desc, BlockDesc* block) const = 0;

  bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const {
    auto var_can_reused = [&](const VarDesc& node) -> bool {
      auto type = node.GetType();
      if (node.Persistable() || type != proto::VarType::LOD_TENSOR ||
          node.GetShape().empty()) {
        return false;
      }
      // Skip framework-internal vars such as @EMPTY@ or @LR_DECAY_REUSE_ID@
      // (used by, for example, while_grad).
      std::string name = node.Name();
      if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
        return false;
      return true;
    };

    auto var_size_in_bytes = [&](const VarDesc& node) -> size_t {
      auto shape = node.GetShape();
      int size = std::accumulate(shape.begin(), shape.end(), 1,
                                 std::multiplies<int>());
      size_t type_size = SizeOfType(node.GetDataType());
      // std::abs guards against dynamic dimensions encoded as -1.
      return type_size * std::abs(size);
    };

    return in.Name() != out.Name() && var_can_reused(in) &&
           var_can_reused(out) &&
           var_size_in_bytes(out) <= var_size_in_bytes(in);
  }
};
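TryInplaceInputOutput above is the heart of the legality check. The following self-contained sketch mirrors its rules with a toy VarInfo standing in for VarDesc (the field names and elem_size are assumptions for illustration): reject persistable, non-LoDTensor, or shape-less variables and framework-internal @...@ names, then allow reuse only when the output is no larger than the input.

#include <cstdlib>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

struct VarInfo {
  std::string name;
  std::vector<int> shape;
  bool persistable = false;
  bool is_lod_tensor = true;
  size_t elem_size = 4;  // e.g. sizeof(float); stands in for SizeOfType
};

bool VarCanBeReused(const VarInfo& v) {
  if (v.persistable || !v.is_lod_tensor || v.shape.empty()) return false;
  // Framework-internal vars such as @EMPTY@ are wrapped in '@'.
  if (!v.name.empty() && v.name.front() == '@' && v.name.back() == '@')
    return false;
  return true;
}

size_t SizeInBytes(const VarInfo& v) {
  int n = std::accumulate(v.shape.begin(), v.shape.end(), 1,
                          std::multiplies<int>());
  // abs() because a dynamic dimension may be encoded as -1.
  return v.elem_size * static_cast<size_t>(std::abs(n));
}

bool TryInplace(const VarInfo& in, const VarInfo& out) {
  return in.name != out.name && VarCanBeReused(in) && VarCanBeReused(out) &&
         SizeInBytes(out) <= SizeInBytes(in);
}

int main() {
  VarInfo x{"x", {32, 128}};
  VarInfo y{"y", {32, 128}};
  return TryInplace(x, y) ? 0 : 1;  // eligible: same size, both reusable
}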
/*
  In-place In->Out for an operator that has only one input and one output,
  for example an activation op.
*/
class SingleOpInplaceInToOut : public InplaceInToOut {
 protected:
  std::unordered_map<std::string, std::string> Apply(
      const OpDesc& op_desc, BlockDesc* block) const override {
    PADDLE_ENFORCE(!op_desc.InputNames().empty(),
                   "Op inputs must not be empty");
    PADDLE_ENFORCE(!op_desc.OutputNames().empty(),
                   "Op outputs must not be empty");
    auto x_name = op_desc.InputNames().at(0);
    auto out_name = op_desc.OutputNames().at(0);
    return std::unordered_map<std::string, std::string>{{x_name, out_name}};
  }
};
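For a relu op, for instance, whose only input slot is "X" and only output slot is "Out", Apply returns {{"X", "Out"}}, and Out then reuses X's buffer once the base class's legality and size checks pass.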
/*
  Gradient op: the output gradient reuses the memory of its corresponding
  forward input, e.g. an Input@GRAD->Input reuse strategy.
*/
class GradOpInplaceInToOut : public InplaceInToOut {
 protected:
  std::unordered_map<std::string, std::string> Apply(
      const OpDesc& op_desc, BlockDesc* block) const override {
    std::unordered_map<std::string, std::string> ret;
    std::unordered_set<std::string> output_names(op_desc.OutputNames().begin(),
                                                 op_desc.OutputNames().end());
    for (auto& input_name : op_desc.InputNames()) {
      if (output_names.count(GradVarName(input_name))) {
        ret.insert({input_name, GradVarName(input_name)});
      }
    }
    return ret;
  }
};
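In Paddle, GradVarName(name) appends the gradient suffix, producing name + "@GRAD". The self-contained sketch below replays the pairing loop above with plain strings to show which inputs get paired:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Stand-in for framework::GradVarName, which appends the "@GRAD" suffix.
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

int main() {
  std::vector<std::string> input_names = {"X", "Y", "Out@GRAD"};
  std::unordered_set<std::string> output_names = {"X@GRAD"};

  std::unordered_map<std::string, std::string> pairs;
  for (const auto& in : input_names) {
    if (output_names.count(GradVarName(in)))
      pairs.insert({in, GradVarName(in)});
  }
  // Prints: X -> X@GRAD. Y has no gradient output, so it is not paired.
  for (const auto& p : pairs)
    std::cout << p.first << " -> " << p.second << "\n";
}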
}  // namespace framework
}  // namespace paddle