|
|
@ -131,61 +131,6 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
|
|
|
|
int post_;
|
|
|
|
int post_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, typename Place>
|
|
|
|
|
|
|
|
class ElementIterator;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Fixed(zcd) : Only support 2D
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
class ElementIterator<T, platform::CPUDeviceContext> {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
ElementIterator(const T* ptr, int t_m, int t_n, int m, int n)
|
|
|
|
|
|
|
|
: ptr_(ptr),
|
|
|
|
|
|
|
|
index_(0),
|
|
|
|
|
|
|
|
i_(0),
|
|
|
|
|
|
|
|
j_(0),
|
|
|
|
|
|
|
|
t_m_(t_m),
|
|
|
|
|
|
|
|
t_n_(t_n),
|
|
|
|
|
|
|
|
m_(m),
|
|
|
|
|
|
|
|
n_(n) {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ElementIterator<T, platform::CPUDeviceContext>& operator++() {
|
|
|
|
|
|
|
|
++j_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ((j_ == n_)) {
|
|
|
|
|
|
|
|
j_ = 0;
|
|
|
|
|
|
|
|
++i_;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
int t_i = (t_m_ == 1) ? 0 : i_;
|
|
|
|
|
|
|
|
int t_j = (t_n_ == 1) ? 0 : j_;
|
|
|
|
|
|
|
|
index_ = t_i * t_n_ + t_j;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return *this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool operator==(
|
|
|
|
|
|
|
|
const ElementIterator<T, platform::CPUDeviceContext>& rhs) const {
|
|
|
|
|
|
|
|
return (ptr_ + index_) == &(*rhs);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool operator!=(
|
|
|
|
|
|
|
|
const ElementIterator<T, platform::CPUDeviceContext>& rhs) const {
|
|
|
|
|
|
|
|
return (ptr_ + index_) != &(*rhs);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[index_]; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
// t_m_ == m_ || t_n_ == n_ || (t_m_ == 1 && t_m_ == 1)
|
|
|
|
|
|
|
|
const T* ptr_;
|
|
|
|
|
|
|
|
int index_;
|
|
|
|
|
|
|
|
int i_;
|
|
|
|
|
|
|
|
int j_;
|
|
|
|
|
|
|
|
int64_t t_m_;
|
|
|
|
|
|
|
|
int64_t t_n_;
|
|
|
|
|
|
|
|
int64_t m_;
|
|
|
|
|
|
|
|
int64_t n_;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
#ifdef __NVCC__
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
|
|
|
|
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
|
|
|
|