|
|
|
@ -49,11 +49,28 @@ bool IsTarget(const proto::OpDesc& op_desc) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
int block_id) {
|
|
|
|
|
// TODO(tonyyang-svail):
|
|
|
|
|
// - will change to use multiple blocks for RNN op and Cond Op
|
|
|
|
|
// Finds the sub-block an operator refers to.
//
// Walks the attributes of `op_desc` in order and returns the block_idx of
// the first attribute whose type is AttrType::BLOCK. Returns -1 when the
// operator has no BLOCK-typed attribute.
int GetSubBlockIndex(const proto::OpDesc& op_desc) {
  constexpr int kNoSubBlock = -1;
  for (const auto& attribute : op_desc.attrs()) {
    if (attribute.type() != proto::AttrType::BLOCK) {
      continue;
    }
    // A BLOCK-typed attribute is required to carry a block index.
    PADDLE_ENFORCE(attribute.has_block_idx());
    return attribute.block_idx();
  }
  return kNoSubBlock;
}
|
|
|
|
|
|
|
|
|
|
// Reports whether `op_desc` refers to a sub-block, i.e. whether
// GetSubBlockIndex() yields a strictly positive block index for it.
bool HasSubBlock(const proto::OpDesc& op_desc) {
  const int sub_block_idx = GetSubBlockIndex(op_desc);
  // Both -1 (no BLOCK attribute found) and 0 count as "no sub-block".
  return sub_block_idx > 0;
}
|
|
|
|
|
|
|
|
|
|
// block_id is the idx of the current block in the input desc
|
|
|
|
|
// parent_block_id is the idx of the parent of the current block
|
|
|
|
|
// in the output desc, -1 means the current block is global block
|
|
|
|
|
// dependent_vars is passed recursively from the parent block to
|
|
|
|
|
// the child block to help pruning
|
|
|
|
|
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
int block_id, int parent_block_id,
|
|
|
|
|
std::set<std::string>& dependent_vars) {
|
|
|
|
|
auto& block = input.blocks(block_id);
|
|
|
|
|
auto& ops = block.ops();
|
|
|
|
|
|
|
|
|
@ -72,11 +89,9 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
expect_fetch = (op_desc.type() == kFetchOpType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<std::string> dependent_vars;
|
|
|
|
|
std::vector<bool> should_run;
|
|
|
|
|
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
|
|
|
|
|
auto& op_desc = *op_iter;
|
|
|
|
|
|
|
|
|
|
if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) {
|
|
|
|
|
// insert its input to the dependency graph
|
|
|
|
|
for (auto& var : op_desc.inputs()) {
|
|
|
|
@ -84,7 +99,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
dependent_vars.insert(argu);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
should_run.push_back(true);
|
|
|
|
|
} else {
|
|
|
|
|
should_run.push_back(false);
|
|
|
|
@ -95,19 +109,48 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
// we reverse the should_run vector
|
|
|
|
|
std::reverse(should_run.begin(), should_run.end());
|
|
|
|
|
|
|
|
|
|
*output = input;
|
|
|
|
|
auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
|
|
|
|
|
//*output = input;
|
|
|
|
|
// copy the current block from input to output
|
|
|
|
|
auto* block_field = output->mutable_blocks();
|
|
|
|
|
*block_field->Add() = input.blocks(block_id);
|
|
|
|
|
|
|
|
|
|
int output_block_id = output->blocks_size() - 1;
|
|
|
|
|
auto* output_block = output->mutable_blocks(output_block_id);
|
|
|
|
|
output_block->set_idx = output_block_id;
|
|
|
|
|
output_block->set_parent_idx = parent_block_id;
|
|
|
|
|
|
|
|
|
|
auto* op_field = output_block->mutable_ops();
|
|
|
|
|
op_field->Clear();
|
|
|
|
|
for (size_t i = 0; i < should_run.size(); ++i) {
|
|
|
|
|
if (should_run[i]) {
|
|
|
|
|
*op_field->Add() = input.blocks(block_id).ops(i);
|
|
|
|
|
auto* op = op_field->Add();
|
|
|
|
|
*op = input.blocks(block_id).ops(i);
|
|
|
|
|
if (HasSubBlock(*op)) {
|
|
|
|
|
// create sub_block_dependent_vars here to help prune the sub block
|
|
|
|
|
std::set<std::string> sub_block_dependent_vars;
|
|
|
|
|
for (auto& var : op.inputs()) {
|
|
|
|
|
for (auto& argu : var.arguments()) {
|
|
|
|
|
sub_block_dependent_vars.insert(argu);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto& var : op.outputs()) {
|
|
|
|
|
for (auto& argu : var.arguments()) {
|
|
|
|
|
sub_block_dependent_vars.insert(argu);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
|
|
|
|
|
// output_block_id is the idx of the current block in the output desc
|
|
|
|
|
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
|
|
|
|
|
sub_block_dependent_vars);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// remove the VarDescs in BlockDesc that are not referenced in
|
|
|
|
|
// the pruned OpDescs
|
|
|
|
|
std::unordered_map<std::string, proto::VarDesc> var_map;
|
|
|
|
|
auto* var_field = output->mutable_blocks(block_id)->mutable_vars();
|
|
|
|
|
auto* var_field = output->mutable_blocks(output_block_id)->mutable_vars();
|
|
|
|
|
for (const auto& var : *var_field) {
|
|
|
|
|
var_map[var.name()] = var;
|
|
|
|
|
}
|
|
|
|
@ -118,14 +161,14 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
auto& input_field = op.inputs();
|
|
|
|
|
for (auto& input_var : input_field) {
|
|
|
|
|
for (auto& arg : input_var.arguments()) {
|
|
|
|
|
*var_field->Add() = var_map[arg];
|
|
|
|
|
*var_field->Add() = var_map.at(arg);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// add VarDescs of all output arguments for each OpDesc
|
|
|
|
|
auto& output_field = op.outputs();
|
|
|
|
|
for (auto& output_var : output_field) {
|
|
|
|
|
for (auto& arg : output_var.arguments()) {
|
|
|
|
|
*var_field->Add() = var_map[arg];
|
|
|
|
|
*var_field->Add() = var_map.at(arg);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -133,7 +176,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
|
|
|
|
|
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
|
|
|
|
|
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
|
|
|
|
|
prune_impl(input, output, 0);
|
|
|
|
|
prune_impl(input, output, 0, -1, {});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void inference_optimize_impl(const proto::ProgramDesc& input,
|
|
|
|
|