diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 09ab42b501..14da42a786 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -70,9 +70,7 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> { RowwiseTransformIterator<T, platform::CPUPlace>& operator++() { ++i_; - if (i_ == n_) { - i_ = 0; - } + i_ %= n_; return *this; } @@ -90,7 +88,7 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> { const T* ptr_; int i_; - int n_; + int64_t n_; }; template <typename T> @@ -99,14 +97,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> { : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} MidWiseTransformIterator<T, platform::CPUPlace>& operator++() { - ++j_; - if (j_ == post_) { - j_ = 0; - ++i_; - if (i_ == n_) { - i_ = 0; - } - } + i_ = ++j_ / post_ % n_; return *this; } @@ -124,8 +115,8 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> { const T* ptr_; int i_; - int j_; - int n_; + int64_t j_; + int64_t n_; int post_; };