Prune Design Doc (#4732)
# Prune
## Motivation
We want to support running inference, training and checkpointing in one `ProgramDesc`. We implement a
`void Prune(const ProgramDesc* input, ProgramDesc* output)` function, which takes a `ProgramDesc`
and generates a pruned `ProgramDesc`.
## Challenge
Pruning needs to support both variables and operators as evaluation targets. Consider the following
situations.
```python
# Case 1: run the forward pass.
cost_np = session.run(target=cost)
# Case 2: run the backward pass.
opts_np, _ = session.run(target=[cost, opt])
# Case 3: run checkpointing.
_ = session.run(target=checkpoint)
```
## Solution
To support evaluation of operators, we add an `is_target` field to the `OpDesc`.
```proto
message OpDesc {
  required string type = 3;
  repeated Var inputs = 1;
  repeated Var outputs = 2;
  repeated Attr attrs = 4;
  optional bool is_target = 5 [ default = false ];
};
```
To support evaluation of variables, we add a [fetch_op](https://github.com/PaddlePaddle/Paddle/pull/4599).
For each variable in the `target`, we insert a `fetch_op` into the `ProgramDesc` with the variable as the
`fetch_op`'s input. Then we also mark the `fetch_op` itself as a target.
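
As a rough sketch, and not the actual implementation, inserting such a `fetch_op` could look like the code below. It assumes the protobuf-generated C++ API for `OpDesc`, that its nested `Var` message carries a `parameter` name plus `arguments`, and the `"Input"` parameter name is made up for illustration.

```c++
// Sketch only: build the OpDesc of a fetch_op for one target variable.
// Only the fields shown in this document (type, inputs, is_target) are used;
// the "Input" parameter name is an assumption.
OpDesc MakeFetchOp(const std::string& var_name) {
  OpDesc fetch_op;
  fetch_op.set_type("fetch");           // the fetch_op reads the target variable
  auto* input = fetch_op.add_inputs();  // Var message: parameter + arguments
  input->set_parameter("Input");        // assumed parameter name
  input->add_arguments(var_name);       // the variable being fetched
  fetch_op.set_is_target(true);         // so pruning keeps the fetch_op itself
  return fetch_op;
}
```

Marking the `fetch_op` itself as a target lets the pruning pass treat fetched variables and target operators uniformly.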
### Algorithm
If an operator needs to be run, it must fall into one of the following cases:
1. It is the target.
2. It is depended on by some other op, meaning its output is some other op's input.
The first case can be checked by `op_desc.is_target()`. The second case can be implemented as follows:
```c++
// Returns true if any output of `op_desc` is needed by an op that is already
// known to run, i.e. appears in `dependent_vars`.
bool HasDependentVar(const OpDesc& op_desc,
                     const std::set<std::string>& dependent_vars) {
  for (auto& var : op_desc.outputs()) {
    for (auto& argu : var.arguments()) {
      if (dependent_vars.count(argu) != 0) {
        return true;
      }
    }
  }
  return false;
}
```
Then the whole algorithm can be implemented as the following [code](https://github.com/tonyyang-svail/Paddle/blob/prune_impl/paddle/framework/prune.cc).
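
For reference, here is a condensed sketch of that algorithm, not the implementation itself. It assumes a single-block `ProgramDesc` (block index 0) and the protobuf-generated accessors; see the linked `prune.cc` for the real code.

```c++
#include <set>
#include <string>
#include <vector>

// Sketch: walk the ops backwards, keep an op if it is a target or if an
// already-kept op consumes one of its outputs, then copy the kept ops into
// the output program. The single-block layout is an assumption.
void Prune(const ProgramDesc* input, ProgramDesc* output) {
  const auto& ops = input->blocks(0).ops();
  std::set<std::string> dependent_vars;
  std::vector<bool> should_run(ops.size(), false);

  // Reverse pass: once an op is kept, every variable it reads becomes a
  // dependency that earlier ops may have to produce.
  for (int i = ops.size() - 1; i >= 0; --i) {
    const auto& op = ops[i];
    if (op.is_target() || HasDependentVar(op, dependent_vars)) {
      for (const auto& var : op.inputs()) {
        for (const auto& arg : var.arguments()) {
          dependent_vars.insert(arg);
        }
      }
      should_run[i] = true;
    }
  }

  // Copy only the ops that must run into the pruned ProgramDesc.
  *output = *input;
  output->mutable_blocks(0)->clear_ops();
  for (int i = 0; i < ops.size(); ++i) {
    if (should_run[i]) {
      *output->mutable_blocks(0)->add_ops() = ops[i];
    }
  }
}
```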