# Prune

## Motivation

We want to support running inference, training, and checkpointing in one `ProgramDesc`. We implement a `void Prune(const ProgramDesc* input, ProgramDesc* output)` function, which takes a `ProgramDesc` and generates a pruned `ProgramDesc`.

## Challenge

Pruning needs to support both variables and operators as evaluation targets. Consider the following situations.

```python
# Case 1: run the forward pass.
cost_np = session.run(target=cost)
# Case 2: run the backward pass.
opts_np, _ = session.run(target=[cost, opt])
# Case 3: run checkpointing.
_ = session.run(target=checkpoint)
```

## Solution

To support evaluation of operators, we add an `is_target` field to `OpDesc`.

```protobuf
message OpDesc {
  required string type = 3;
  repeated Var inputs = 1;
  repeated Var outputs = 2;
  repeated Attr attrs = 4;
  optional bool is_target = 5 [ default = false ];
};
```

To support evaluation of variables, we add [fetch_op](https://github.com/PaddlePaddle/Paddle/pull/4599). For each variable in the `target`, we insert a `fetch_op` into the `ProgramDesc` with that variable as the `fetch_op`'s input. We then mark the `fetch_op` as a target.

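The insertion step can be sketched in Python as follows. This is only an illustration under stated assumptions: the `OpDesc` dataclass is a simplified, hypothetical stand-in for the protobuf message, and the op type name `"fetch"`, the input slot name `"X"`, and the helper `append_fetch_ops` are made up for this example rather than taken from the Paddle code base.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class OpDesc:
    """Hypothetical, simplified stand-in for the protobuf OpDesc."""
    type: str
    inputs: Dict[str, List[str]] = field(default_factory=dict)
    outputs: Dict[str, List[str]] = field(default_factory=dict)
    is_target: bool = False


def append_fetch_ops(ops: List[OpDesc], target_vars: List[str]) -> List[OpDesc]:
    """Return a copy of `ops` with one target fetch op per target variable."""
    result = list(ops)
    for name in target_vars:
        # The fetch op reads the variable and is itself marked as a target,
        # so the variable's producers survive pruning.
        result.append(OpDesc(type="fetch", inputs={"X": [name]}, is_target=True))
    return result


ops = [
    OpDesc("mul", {"X": ["x"], "Y": ["w"]}, {"Out": ["y"]}),
    OpDesc("mean", {"X": ["y"]}, {"Out": ["cost"]}),
]
with_fetch = append_fetch_ops(ops, ["cost"])
```

With this convention, a variable target is reduced to an operator target, so the pruning pass only has to reason about `is_target` on operators.
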
### Algorithm

If an operator needs to be run, it must fall into one of the following cases:

1. It is a target.
2. Some other op depends on it, meaning its output is another op's input.

The first case can be checked with `op_desc.is_target()`. The second case can be implemented as

```c++
#include <set>
#include <string>

// Returns true if any output argument of `op_desc` is in `dependent_vars`,
// i.e. some op that must run consumes one of this op's outputs.
bool HasDependentVar(const OpDesc& op_desc,
                     const std::set<std::string>& dependent_vars) {
  for (const auto& var : op_desc.outputs()) {
    for (const auto& argu : var.arguments()) {
      if (dependent_vars.count(argu) != 0) {
        return true;
      }
    }
  }
  return false;
}
```

The whole algorithm can then be implemented as in the following [code](https://github.com/tonyyang-svail/Paddle/blob/prune_impl/paddle/framework/prune.cc).
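
As a rough illustration of how the two cases combine, the pass can be sketched in Python as a single reverse scan over the ops. This is not the linked `prune.cc`. It assumes the ops of one block are stored in execution order, and the `OpDesc` dataclass, the op type names, and the function names are hypothetical simplifications.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class OpDesc:
    """Hypothetical, simplified stand-in for the protobuf OpDesc."""
    type: str
    inputs: Dict[str, List[str]] = field(default_factory=dict)
    outputs: Dict[str, List[str]] = field(default_factory=dict)
    is_target: bool = False


def has_dependent_var(op: OpDesc, dependent_vars: Set[str]) -> bool:
    """Case 2: one of the op's outputs is needed by an op that will run."""
    return any(name in dependent_vars
               for names in op.outputs.values() for name in names)


def prune(ops: List[OpDesc]) -> List[OpDesc]:
    """Keep only targets and the ops they transitively depend on."""
    dependent_vars: Set[str] = set()
    kept: List[OpDesc] = []
    # Walk backwards so consumers are seen before their producers.
    for op in reversed(ops):
        if op.is_target or has_dependent_var(op, dependent_vars):
            for names in op.inputs.values():
                dependent_vars.update(names)
            kept.append(op)
    kept.reverse()  # restore execution order
    return kept


program = [
    OpDesc("mul", {"X": ["x"], "Y": ["w"]}, {"Out": ["y"]}),
    OpDesc("mean", {"X": ["y"]}, {"Out": ["cost"]}),
    OpDesc("save", {"X": ["w"]}, {}),  # checkpoint op, not a target here
    OpDesc("fetch", {"X": ["cost"]}, {}, is_target=True),
]
pruned_types = [op.type for op in prune(program)]
```

Here only the `fetch` op is a target, so the sketch keeps `mul`, `mean`, and `fetch` and drops `save`. Marking `save` as the only target instead would keep `save` alone, which corresponds to the checkpointing case.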