|
|
|
@ -14,19 +14,19 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/prune.h"
|
|
|
|
|
|
|
|
|
|
#include <glog/logging.h>
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include <glog/logging.h>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
// Operator type names for the "feed" and "fetch" ops.
//
// These are plain char arrays rather than std::string objects so that they
// are constant-initialized and have trivial destructors: no code runs for
// them at static-initialization or shutdown time. Each name is defined
// exactly once; a second definition with a different type would be a
// redefinition error.
const char kFeedOpType[] = "feed";
const char kFetchOpType[] = "fetch";
|
|
|
|
|
|
|
|
|
|
bool HasDependentVar(const proto::OpDesc& op_desc,
|
|
|
|
|
const std::set<std::string>& dependent_vars) {
|
|
|
|
@ -68,7 +68,7 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
|
|
|
|
|
// the child block to help pruning
|
|
|
|
|
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
int block_id, int parent_block_id,
|
|
|
|
|
std::set<std::string>& dependent_vars) {
|
|
|
|
|
std::set<std::string>* dependent_vars) {
|
|
|
|
|
auto& block = input.blocks(block_id);
|
|
|
|
|
auto& ops = block.ops();
|
|
|
|
|
|
|
|
|
@ -90,11 +90,11 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
std::vector<bool> should_run;
|
|
|
|
|
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
|
|
|
|
|
auto& op_desc = *op_iter;
|
|
|
|
|
if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) {
|
|
|
|
|
if (IsTarget(op_desc) || HasDependentVar(op_desc, *dependent_vars)) {
|
|
|
|
|
// insert its input to the dependency graph
|
|
|
|
|
for (auto& var : op_desc.inputs()) {
|
|
|
|
|
for (auto& argu : var.arguments()) {
|
|
|
|
|
dependent_vars.insert(argu);
|
|
|
|
|
dependent_vars->insert(argu);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
should_run.push_back(true);
|
|
|
|
@ -138,7 +138,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
|
|
|
|
|
// output_block_id is the idx of the current block in the output desc
|
|
|
|
|
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
|
|
|
|
|
sub_block_dependent_vars);
|
|
|
|
|
&sub_block_dependent_vars);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -181,7 +181,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
// Prunes the program described by `input` and writes the result to `output`.
//
// Any blocks already present in `output` are cleared first. Pruning then
// starts at block 0 with a parent block id of -1 and an initially empty set
// of dependent variable names.
//
// The set is handed to prune_impl by pointer because prune_impl inserts the
// inputs of every op it keeps into it; it must be called exactly once here,
// otherwise the blocks would be processed a second time.
//
// @param input   program description to prune; not modified.
// @param output  destination program description; its blocks are replaced.
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
  std::set<std::string> dependent_vars;
  output->clear_blocks();
  prune_impl(input, output, 0, -1, &dependent_vars);
}
|
|
|
|
|
|
|
|
|
|
void inference_optimize_impl(proto::ProgramDesc* input, int block_id) {
|
|
|
|
|