@@ -80,6 +80,70 @@ void InferShape(const std::shared_ptr<Scope>& scope) {
void CopyInSeqToOut();
```
## Sort by length

Once the sequences are sorted by length (longest first), the batch size shrinks naturally from one time step to the next, which Net supports.

For example:
```
origin:
xxxx
xx
xxx

-> sorted:
xxxx
xxx
xx
```

After `SegmentInputs`, there are 4 time steps; the inputs of each time step are shown below (one column per step):
```
0 1 2 3
x x x x
x x x
x x
```
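
As a rough illustration of why sorting yields shrinking batches, here is a minimal sketch that works on sequence lengths only. The names `SortIndicesByLengthSketch` and `BatchSizePerStepSketch` are hypothetical and not part of the design; the sketch assumes a longest-first order.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: returns the original indices of the sequences,
// ordered from longest to shortest. std::stable_sort keeps the original
// relative order of sequences that have the same length.
std::vector<std::size_t> SortIndicesByLengthSketch(
    const std::vector<std::size_t>& lengths) {
  std::vector<std::size_t> order(lengths.size());
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::stable_sort(order.begin(), order.end(),
                   [&lengths](std::size_t a, std::size_t b) {
                     return lengths[a] > lengths[b];
                   });
  return order;
}

// Hypothetical helper: the batch size at step t is the number of
// sequences that are longer than t.
std::vector<std::size_t> BatchSizePerStepSketch(
    const std::vector<std::size_t>& lengths) {
  std::size_t max_len = 0;
  for (std::size_t len : lengths) max_len = std::max(max_len, len);
  std::vector<std::size_t> sizes(max_len, 0);
  for (std::size_t len : lengths) {
    for (std::size_t t = 0; t < len; ++t) ++sizes[t];
  }
  return sizes;
}
```

For the lengths 4, 2, 3 used in the example, the sorted order is sequence 0, then 2, then 1, and the per-step batch sizes are 3, 3, 2, 1, matching the columns in the diagram.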

To track how the sequences are reordered by sorting, we use

```c++
#include <vector>

struct SortedSeqItem {
  void *start{nullptr};  // where the sequence starts
  void *end{nullptr};    // where the sequence ends
};

std::vector<SortedSeqItem> sorted_seqs;
```
to record where each sequence ends up after sorting.
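
The design names a `SortBySeqLen` interface but does not give its signature, so the following is only a sketch under stated assumptions: the sequences are stored back to back in one `float` buffer, `end` points one past the last element, and the function name `SortBySeqLenSketch`, its parameters, and the `ByteLen` helper are all hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

struct SortedSeqItem {
  void* start{nullptr};
  void* end{nullptr};
};

// Hypothetical helper: size in bytes of the range tracked by an item.
inline std::ptrdiff_t ByteLen(const SortedSeqItem& item) {
  return static_cast<char*>(item.end) - static_cast<char*>(item.start);
}

// Hypothetical sketch of SortBySeqLen. `data` is assumed to hold all
// sequences contiguously, sequence i occupying lengths[i] floats.
// Returns one item per sequence, ordered longest first.
std::vector<SortedSeqItem> SortBySeqLenSketch(
    float* data, const std::vector<std::size_t>& lengths) {
  std::vector<SortedSeqItem> sorted_seqs;
  sorted_seqs.reserve(lengths.size());
  float* cursor = data;
  for (std::size_t len : lengths) {
    SortedSeqItem item;
    item.start = cursor;
    item.end = cursor + len;  // assumed: one past the last element
    sorted_seqs.push_back(item);
    cursor += len;
  }
  // Stable sort so equal-length sequences keep their original order.
  std::stable_sort(sorted_seqs.begin(), sorted_seqs.end(),
                   [](const SortedSeqItem& a, const SortedSeqItem& b) {
                     return ByteLen(a) > ByteLen(b);
                   });
  return sorted_seqs;
}
```

Because each item keeps pointers into the original buffer, the original position of a sequence can still be recovered after sorting, which is what the later steps rely on.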

Compared with the existing design, only two interfaces need to change, `SegmentInputs` and `ConcatOutputs`, plus one new interface, `SortBySeqLen`,
to support the variable-length sequences described above. The details follow.
## SegmentInputs
`SegmentInputs` relies on the information in `sorted_seqs`: it takes the original sequences in their sorted order and slices them across the time axis, turning them into the inputs of each step.

That is, the following transformation:
```
origin:
xxxx
xx
xxx

   |
   |
  \ /
   *
0 1 2 3
x x x x
x x x
x x
```
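
The transformation in the diagram can be sketched as follows. This is a simplified illustration, not the actual interface: sequences are modeled as `std::vector<int>` instead of tensors, the input is assumed to be already sorted longest first, and the name `SegmentInputsSketch` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Sequence = std::vector<int>;

// Hypothetical sketch: given sequences sorted longest first, build the
// input of every time step. Step t contains element t of each sequence
// that is longer than t, in sorted order.
std::vector<std::vector<int>> SegmentInputsSketch(
    const std::vector<Sequence>& sorted_seqs) {
  std::vector<std::vector<int>> steps;
  if (sorted_seqs.empty()) return steps;
  // The longest (first) sequence determines the number of time steps.
  steps.resize(sorted_seqs.front().size());
  for (const Sequence& seq : sorted_seqs) {
    // No sequence may be longer than the first one.
    assert(seq.size() <= steps.size());
    for (std::size_t t = 0; t < seq.size(); ++t) {
      steps[t].push_back(seq[t]);
    }
  }
  return steps;
}
```

With the sorted sequences `{1,2,3,4}`, `{5,6,7}`, `{8,9}`, the four steps receive `{1,5,8}`, `{2,6,9}`, `{3,7}` and `{4}`. The sequences that are still active at a step always form a prefix of the sorted order, so each step is one dense batch.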
## ConcatOutputs
`ConcatOutputs` needs to:

- restore the outputs of every time step to the sequence order of the original input (so that the order is not scrambled at the Infer stage)
- fold the sequences back together, expanding them along the batch dimension
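
A minimal sketch of the first point, restoring the original order, might look like this. It assumes that the outputs of each step are in sorted order and that a hypothetical `order` vector maps each sorted position to the index of the sequence in the original input; `ConcatOutputsSketch` is an illustrative name, and plain `int` values stand in for tensors.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: step_outputs[t][i] is the output at step t of the
// sequence at sorted position i, and order[i] is that sequence's index in
// the original input. Returns the outputs regrouped per sequence, in the
// original input order.
std::vector<std::vector<int>> ConcatOutputsSketch(
    const std::vector<std::vector<int>>& step_outputs,
    const std::vector<std::size_t>& order) {
  std::vector<std::vector<int>> seqs(order.size());
  for (const std::vector<int>& step : step_outputs) {
    // A step can never hold more entries than there are sequences.
    assert(step.size() <= order.size());
    for (std::size_t i = 0; i < step.size(); ++i) {
      assert(order[i] < seqs.size());
      seqs[order[i]].push_back(step[i]);
    }
  }
  return seqs;
}
```

For step outputs `{1,5,8}`, `{2,6,9}`, `{3,7}`, `{4}` and `order = {0, 2, 1}`, the result is `{1,2,3,4}`, `{8,9}`, `{5,6,7}`, which is the order of the original input (lengths 4, 2, 3).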

## References

1. [TensorFlow Bucketing](https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing)
2. [MXNet Bucketing](http://mxnet.io/how_to/bucketing.html)