|
|
|
@ -70,9 +70,7 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
|
++i_;
|
|
|
|
|
if (i_ == n_) {
|
|
|
|
|
i_ = 0;
|
|
|
|
|
}
|
|
|
|
|
i_ %= n_;
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -90,7 +88,7 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
const T* ptr_;
|
|
|
|
|
int i_;
|
|
|
|
|
int n_;
|
|
|
|
|
int64_t n_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -99,14 +97,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
|
|
|
|
|
|
|
|
|
|
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
|
++j_;
|
|
|
|
|
if (j_ == post_) {
|
|
|
|
|
j_ = 0;
|
|
|
|
|
++i_;
|
|
|
|
|
if (i_ == n_) {
|
|
|
|
|
i_ = 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
i_ = ++j_ / post_ % n_;
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -124,8 +115,8 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
const T* ptr_;
|
|
|
|
|
int i_;
|
|
|
|
|
int j_;
|
|
|
|
|
int n_;
|
|
|
|
|
int64_t j_;
|
|
|
|
|
int64_t n_;
|
|
|
|
|
int post_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|