|
|
|
@ -16,6 +16,11 @@
|
|
|
|
|
#include "paddle/framework/eigen.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "paddle/platform/transform.h"
|
|
|
|
|
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
#include <thrust/iterator/iterator_adaptor.h>
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
@ -54,6 +59,153 @@ inline void get_mid_dims(const framework::DDim& x_dims,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Primary templates for the broadcasting iterators. Only declarations live
// here; the usable implementations are the per-Place partial specializations
// that follow in this file.

// Iterator over a buffer that is revisited cyclically.
template <class T, class Place>
class RowwiseTransformIterator;

// Iterator over a buffer whose elements are each repeated before moving on,
// with the whole buffer revisited cyclically.
template <class T, class Place>
class MidWiseTransformIterator;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
public:
|
|
|
|
|
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
|
|
|
|
|
|
|
|
|
|
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
|
++i_;
|
|
|
|
|
i_ %= n_;
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator==(
|
|
|
|
|
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return (ptr_ + i_) == &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator!=(
|
|
|
|
|
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return (ptr_ + i_) != &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const T* ptr_;
|
|
|
|
|
int i_;
|
|
|
|
|
int64_t n_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
public:
|
|
|
|
|
MidWiseTransformIterator(const T* ptr, int n, int post)
|
|
|
|
|
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
|
|
|
|
|
|
|
|
|
|
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
|
i_ = (++j_ / post_) % n_;
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator==(
|
|
|
|
|
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return (ptr_ + i_) == &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator!=(
|
|
|
|
|
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return (ptr_ + i_) != &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const T* ptr_;
|
|
|
|
|
int i_;
|
|
|
|
|
int64_t j_;
|
|
|
|
|
int64_t n_;
|
|
|
|
|
int post_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
template <typename T>
|
|
|
|
|
class RowwiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
: public thrust::iterator_adaptor<
|
|
|
|
|
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
|
|
|
|
|
public:
|
|
|
|
|
typedef thrust::iterator_adaptor<
|
|
|
|
|
RowwiseTransformIterator<T, platform::GPUPlace>, const T*>
|
|
|
|
|
super_t;
|
|
|
|
|
HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
|
|
|
|
|
: super_t(x), begin_(x), n_(n){};
|
|
|
|
|
friend class thrust::iterator_core_access;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
unsigned int n_;
|
|
|
|
|
const T* begin_;
|
|
|
|
|
HOSTDEVICE typename super_t::reference dereference() const {
|
|
|
|
|
return *(begin_ + (this->base() - begin_) % n_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class MidWiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
: public thrust::iterator_adaptor<
|
|
|
|
|
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
|
|
|
|
|
public:
|
|
|
|
|
typedef thrust::iterator_adaptor<
|
|
|
|
|
MidWiseTransformIterator<T, platform::GPUPlace>, const T*>
|
|
|
|
|
super_t;
|
|
|
|
|
HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
|
|
|
|
|
: super_t(x), begin_(x), n_(n), post_(post){};
|
|
|
|
|
friend class thrust::iterator_core_access;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
unsigned int post_;
|
|
|
|
|
unsigned int n_;
|
|
|
|
|
const T* begin_;
|
|
|
|
|
HOSTDEVICE typename super_t::reference dereference() const {
|
|
|
|
|
return *(begin_ + (((this->base() - begin_) / post_) % n_));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
template <typename Functor, typename T, typename Place>
|
|
|
|
|
class TransformFunctor {
|
|
|
|
|
public:
|
|
|
|
|
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
|
|
|
|
|
framework::Tensor* z, const platform::DeviceContext& ctx,
|
|
|
|
|
Functor func)
|
|
|
|
|
: x_(x->data<T>()),
|
|
|
|
|
y_(y->data<T>()),
|
|
|
|
|
z_(z->mutable_data<T>(ctx.GetPlace())),
|
|
|
|
|
nx_(x->numel()),
|
|
|
|
|
ctx_(ctx),
|
|
|
|
|
func_(func) {}
|
|
|
|
|
|
|
|
|
|
inline void Run() const {
|
|
|
|
|
platform::Transform<Place> trans;
|
|
|
|
|
trans(ctx_, x_, x_ + nx_, y_, z_, func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void RunRowWise(int n, int pre) const {
|
|
|
|
|
platform::Transform<Place> trans;
|
|
|
|
|
trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, Place>(y_, n), z_,
|
|
|
|
|
func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void RunMidWise(int n, int pre, int post) const {
|
|
|
|
|
platform::Transform<Place> trans;
|
|
|
|
|
trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator<T, Place>(y_, n, post),
|
|
|
|
|
z_, func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const T* x_;
|
|
|
|
|
const T* y_;
|
|
|
|
|
T* z_;
|
|
|
|
|
int64_t nx_;
|
|
|
|
|
const platform::DeviceContext& ctx_;
|
|
|
|
|
Functor func_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define EIGEN_FUNCTOR(name, eigen_op) \
|
|
|
|
|
struct Eigen##name##Functor { \
|
|
|
|
|
template <typename Place, typename T> \
|
|
|
|
|