|
|
@ -71,7 +71,9 @@ class RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
|
|
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
++i_;
|
|
|
|
++i_;
|
|
|
|
i_ %= n_;
|
|
|
|
if (UNLIKELY(i_ == n_)) {
|
|
|
|
|
|
|
|
i_ = 0;
|
|
|
|
|
|
|
|
}
|
|
|
|
return *this;
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -100,7 +102,12 @@ class MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
|
|
|
|
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
|
|
|
|
|
|
|
|
|
|
|
|
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
i_ = (++j_ / post_) % n_;
|
|
|
|
++j_;
|
|
|
|
|
|
|
|
i_ = j_ / post_;
|
|
|
|
|
|
|
|
if (UNLIKELY(i_ == n_)) {
|
|
|
|
|
|
|
|
j_ = 0;
|
|
|
|
|
|
|
|
i_ = 0;
|
|
|
|
|
|
|
|
}
|
|
|
|
return *this;
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|