|
|
|
@ -113,7 +113,6 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int GetOpRole(const proto::OpDesc& op_desc) {
|
|
|
|
|
// The op role >= 0, so -1 is used to indicate "NotFound".
|
|
|
|
|
for (auto& attr : op_desc.attrs()) {
|
|
|
|
|
if (attr.name() == OpProtoAndCheckerMaker::OpRoleAttrName()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
@ -124,7 +123,10 @@ int GetOpRole(const proto::OpDesc& op_desc) {
|
|
|
|
|
return attr.i();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return -1;
|
|
|
|
|
// If attr op_role is not found, it may be operator created in c++ test, like
|
|
|
|
|
// prune_test.cc. In that case, the op_role should be default value, which is
|
|
|
|
|
// kNotSpecified.
|
|
|
|
|
return static_cast<int>(OpRole::kNotSpecified);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AppendOpInputVarNames(const proto::OpDesc& op_desc,
|
|
|
|
@ -145,6 +147,16 @@ void AppendOpOutputVarNames(const proto::OpDesc& op_desc,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Reverse lookup: scans `m` in its iteration order and returns the key of the
// first entry whose mapped value equals `val`.
// The content in map should be >= 0, so -1 is returned to signal "NotFound".
int FindMapByValue(const std::map<int, int>& m, int val) {
  int matched_key = -1;  // "NotFound" sentinel until a match is seen
  for (auto entry = m.begin(); entry != m.end(); ++entry) {
    if (entry->second != val) {
      continue;
    }
    // Stop at the first match so later entries with the same value are
    // ignored.
    matched_key = entry->first;
    break;
  }
  return matched_key;
}
|
|
|
|
|
|
|
|
|
|
// block_id is the idx of the current block in the input desc
|
|
|
|
|
// parent_block_id is the idx of the parent of the current block
|
|
|
|
|
// in the output desc, -1 means the current block is global block
|
|
|
|
@ -153,30 +165,41 @@ void AppendOpOutputVarNames(const proto::OpDesc& op_desc,
|
|
|
|
|
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
int block_id, int parent_block_id,
|
|
|
|
|
std::unordered_set<std::string>* dependent_vars,
|
|
|
|
|
const std::set<std::string> feed_var_names) {
|
|
|
|
|
const std::set<std::string> feed_var_names,
|
|
|
|
|
std::map<int, int>* pruned_origin_block_id_map) {
|
|
|
|
|
auto& block = input.blocks(block_id);
|
|
|
|
|
auto& ops = block.ops();
|
|
|
|
|
|
|
|
|
|
bool expect_feed = true;
|
|
|
|
|
for (auto& op_desc : ops) {
|
|
|
|
|
PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
|
|
|
|
|
"All FeedOps are at the beginning of the ProgramDesc");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
op_desc.type() != kFeedOpType || expect_feed, true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"All FeedOps are at the beginning of the ProgramDesc"));
|
|
|
|
|
expect_feed = (op_desc.type() == kFeedOpType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool expect_fetch = true;
|
|
|
|
|
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
|
|
|
|
|
auto& op_desc = *op_iter;
|
|
|
|
|
PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
|
|
|
|
|
"All FetchOps must at the end of the ProgramDesc");
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.type() != kFetchOpType || expect_fetch, true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"All FetchOps must at the end of the ProgramDesc"));
|
|
|
|
|
expect_fetch = (op_desc.type() == kFetchOpType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<bool> should_run;
|
|
|
|
|
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
|
|
|
|
|
auto& op_desc = *op_iter;
|
|
|
|
|
if (IsTarget(op_desc) || HasDependentOutputVar(op_desc, *dependent_vars)) {
|
|
|
|
|
// insert its input to the dependency graph
|
|
|
|
|
|
|
|
|
|
if (IsTarget(op_desc) ||
|
|
|
|
|
(HasDependentOutputVar(op_desc, *dependent_vars) &&
|
|
|
|
|
(GetOpRole(op_desc) & static_cast<int>(OpRole::kOptimize)) == 0)) {
|
|
|
|
|
// NOTE(zhiqiu): since optimize op takes the trainable parameters as
|
|
|
|
|
// inputs and output, it may introduce wrong dependency graph.
|
|
|
|
|
// For train mode, the optimize op should be in targets, so is not need
|
|
|
|
|
// and not right to mark optimize op by its outputs.
|
|
|
|
|
// For eval / infer mode, there is no optimize op in program.
|
|
|
|
|
for (auto& var : op_desc.inputs()) {
|
|
|
|
|
for (auto& argu : var.arguments()) {
|
|
|
|
|
if (feed_var_names.count(argu) == 0) {
|
|
|
|
@ -203,6 +226,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
output_block->set_idx(output_block_id);
|
|
|
|
|
output_block->set_parent_idx(parent_block_id);
|
|
|
|
|
|
|
|
|
|
(*pruned_origin_block_id_map)[output_block_id] = block_id;
|
|
|
|
|
|
|
|
|
|
auto* op_field = output_block->mutable_ops();
|
|
|
|
|
op_field->Clear();
|
|
|
|
|
for (size_t i = 0; i < should_run.size(); ++i) {
|
|
|
|
@ -244,7 +269,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
|
|
|
|
|
// output_block_id is the idx of the current block in the output desc
|
|
|
|
|
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
|
|
|
|
|
&sub_block_dependent_vars, feed_var_names);
|
|
|
|
|
&sub_block_dependent_vars, feed_var_names,
|
|
|
|
|
pruned_origin_block_id_map);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -284,22 +310,33 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
|
|
|
|
|
void Prune(const proto::ProgramDesc& input,
|
|
|
|
|
const std::set<std::string>& feed_var_names,
|
|
|
|
|
proto::ProgramDesc* output) {
|
|
|
|
|
std::map<int, int> Prune(const proto::ProgramDesc& input,
|
|
|
|
|
const std::set<std::string>& feed_var_names,
|
|
|
|
|
proto::ProgramDesc* output) {
|
|
|
|
|
std::unordered_set<std::string> dependent_vars;
|
|
|
|
|
output->clear_blocks();
|
|
|
|
|
prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int FindMapByValue(const std::map<int, int>& m, int val) {
|
|
|
|
|
// The content in map should be >= 0, so -1 is used to indicate "NotFound".
|
|
|
|
|
for (auto& pair : m) {
|
|
|
|
|
if (pair.second == val) {
|
|
|
|
|
return pair.first;
|
|
|
|
|
std::map<int, int> pruned_origin_block_id_map;
|
|
|
|
|
prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names,
|
|
|
|
|
&pruned_origin_block_id_map);
|
|
|
|
|
// update subblock idx
|
|
|
|
|
for (int i = 0; i < output->blocks_size(); i++) {
|
|
|
|
|
auto* pruned = output->mutable_blocks(i);
|
|
|
|
|
auto* ops = pruned->mutable_ops();
|
|
|
|
|
for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) {
|
|
|
|
|
auto& op_desc = *op_iter;
|
|
|
|
|
if (HasSubBlock(op_desc)) {
|
|
|
|
|
int origin_sub_idx = GetSubBlockIndex(op_desc);
|
|
|
|
|
auto sub_idx =
|
|
|
|
|
FindMapByValue(pruned_origin_block_id_map, origin_sub_idx);
|
|
|
|
|
PADDLE_ENFORCE_NE(sub_idx, -1,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"The origin sub block id should be found in "
|
|
|
|
|
"pruned_progin_block_id_map"));
|
|
|
|
|
SetSubBlockIndex(&op_desc, sub_idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return -1;
|
|
|
|
|
return pruned_origin_block_id_map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
|
|
|
|
@ -348,8 +385,8 @@ void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
|
|
|
|
|
var_names.insert(op_output_vars.begin(), op_output_vars.end());
|
|
|
|
|
for (const auto& name : var_names) {
|
|
|
|
|
if (var_map.count(name)) {
|
|
|
|
|
// NOTE(zhiqiu): For operator in a conditional block, the related vars may
|
|
|
|
|
// not exist in current block, but in its futher block.
|
|
|
|
|
// NOTE(zhiqiu): For operator in a conditional block, the related vars
|
|
|
|
|
// may not exist in current block, but in its futher block.
|
|
|
|
|
*pruned_vars->Add() = var_map[name];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -389,6 +426,7 @@ std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
|
|
|
|
|
|
|
|
|
|
proto::ProgramDesc pruned_desc;
|
|
|
|
|
pruned_desc.clear_blocks();
|
|
|
|
|
|
|
|
|
|
// Step 2. Prune backward for each block.
|
|
|
|
|
for (size_t i = 0; i < origin_clone.Size(); i++) {
|
|
|
|
|
auto pruned = proto::BlockDesc();
|
|
|
|
|