|
|
|
@ -81,7 +81,7 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
bool operator!=(
|
|
|
|
|
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return (ptr_ + i_) &= &(*rhs);
|
|
|
|
|
return (ptr_ + i_) != &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
@ -97,7 +97,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
|
|
|
|
|
|
|
|
|
|
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
|
i_ = ++j_ / post_ % n_;
|
|
|
|
|
i_ = (++j_ / post_) % n_;
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -108,7 +108,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
bool operator!=(
|
|
|
|
|
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return (ptr_ + i_) &= &(*rhs);
|
|
|
|
|
return (ptr_ + i_) != &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
@ -129,14 +129,14 @@ struct RowwiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
typedef thrust::iterator_adaptor<
|
|
|
|
|
RowwiseTransformIterator<T, platform::GPUPlace>, const T*>
|
|
|
|
|
super_t;
|
|
|
|
|
__host__ __device__ RowwiseTransformIterator(const T* x, int n)
|
|
|
|
|
HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
|
|
|
|
|
: super_t(x), begin_(x), n_(n){};
|
|
|
|
|
friend class thrust::iterator_core_access;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
unsigned int n_;
|
|
|
|
|
const T* begin_;
|
|
|
|
|
__host__ __device__ typename super_t::reference dereference() const {
|
|
|
|
|
HOSTDEVICE typename super_t::reference dereference() const {
|
|
|
|
|
return *(begin_ + (this->base() - begin_) % n_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -149,7 +149,7 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
typedef thrust::iterator_adaptor<
|
|
|
|
|
MidWiseTransformIterator<T, platform::GPUPlace>, const T*>
|
|
|
|
|
super_t;
|
|
|
|
|
__host__ __device__ MidWiseTransformIterator(const T* x, int n, int post)
|
|
|
|
|
HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
|
|
|
|
|
: super_t(x), begin_(x), n_(n), post_(post){};
|
|
|
|
|
friend class thrust::iterator_core_access;
|
|
|
|
|
|
|
|
|
@ -157,7 +157,7 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
unsigned int post_;
|
|
|
|
|
unsigned int n_;
|
|
|
|
|
const T* begin_;
|
|
|
|
|
__host__ __device__ typename super_t::reference dereference() const {
|
|
|
|
|
HOSTDEVICE typename super_t::reference dereference() const {
|
|
|
|
|
return *(begin_ + (((this->base() - begin_) / post_) % n_));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -166,7 +166,7 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
template <typename Functor, typename T, typename Place>
|
|
|
|
|
struct TransformFunctor {
|
|
|
|
|
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
|
|
|
|
|
framework::Tensor* z, const framework::ExecutionContext& ctx,
|
|
|
|
|
framework::Tensor* z, const platform::DeviceContext& ctx,
|
|
|
|
|
Functor func)
|
|
|
|
|
: x_(x->data<T>()),
|
|
|
|
|
y_(y->data<T>()),
|
|
|
|
@ -177,26 +177,26 @@ struct TransformFunctor {
|
|
|
|
|
|
|
|
|
|
inline void Run() const {
|
|
|
|
|
platform::Transform<Place> trans;
|
|
|
|
|
trans(ctx_.device_context(), x_, x_ + nx_, y_, z_, func_);
|
|
|
|
|
trans(ctx_, x_, x_ + nx_, y_, z_, func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void RunRowWise(int n, int pre) const {
|
|
|
|
|
platform::Transform<Place> trans;
|
|
|
|
|
trans(ctx_.device_context(), x_, x_ + nx_,
|
|
|
|
|
RowwiseTransformIterator<T, Place>(y_, n), z_, func_);
|
|
|
|
|
trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, Place>(y_, n), z_,
|
|
|
|
|
func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void RunMidWise(int n, int pre, int post) const {
|
|
|
|
|
platform::Transform<Place> trans;
|
|
|
|
|
trans(ctx_.device_context(), x_, x_ + nx_,
|
|
|
|
|
MidWiseTransformIterator<T, Place>(y_, n, post), z_, func_);
|
|
|
|
|
trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator<T, Place>(y_, n, post),
|
|
|
|
|
z_, func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T* x_;
|
|
|
|
|
const T* y_;
|
|
|
|
|
T* z_;
|
|
|
|
|
int64_t nx_;
|
|
|
|
|
const framework::ExecutionContext& ctx_;
|
|
|
|
|
const platform::DeviceContext& ctx_;
|
|
|
|
|
Functor func_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|