|
|
|
@ -18,6 +18,10 @@
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "paddle/platform/transform.h"
|
|
|
|
|
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
#include <thrust/iterator/iterator_adaptor.h>
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -74,12 +78,12 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
bool operator==(
|
|
|
|
|
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return &(this->operator*()) == &(*rhs);
|
|
|
|
|
return (ptr_ + i_) == &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator!=(
|
|
|
|
|
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return &(this->operator*()) &= &(*rhs);
|
|
|
|
|
return (ptr_ + i_) &= &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
@ -108,12 +112,12 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
bool operator==(
|
|
|
|
|
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return &(this->operator*()) == &(*rhs);
|
|
|
|
|
return (ptr_ + i_) == &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator!=(
|
|
|
|
|
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return &(this->operator*()) &= &(*rhs);
|
|
|
|
|
return (ptr_ + i_) &= &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
@ -125,6 +129,49 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
int post_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct RowwiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
: public thrust::iterator_adaptor<
|
|
|
|
|
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
|
|
|
|
|
public:
|
|
|
|
|
typedef thrust::iterator_adaptor<
|
|
|
|
|
RowwiseTransformIterator<T, platform::GPUPlace>, const T*>
|
|
|
|
|
super_t;
|
|
|
|
|
__host__ __device__ RowwiseTransformIterator(const T* x, int n)
|
|
|
|
|
: super_t(x), begin_(x), n_(n){};
|
|
|
|
|
friend class thrust::iterator_core_access;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
unsigned int n_;
|
|
|
|
|
const T* begin_;
|
|
|
|
|
__host__ __device__ typename super_t::reference dereference() const {
|
|
|
|
|
return *(begin_ + (this->base() - begin_) % n_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct MidWiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
: public thrust::iterator_adaptor<
|
|
|
|
|
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
|
|
|
|
|
public:
|
|
|
|
|
typedef thrust::iterator_adaptor<
|
|
|
|
|
MidWiseTransformIterator<T, platform::GPUPlace>, const T*>
|
|
|
|
|
super_t;
|
|
|
|
|
__host__ __device__ MidWiseTransformIterator(const T* x, int n, int post)
|
|
|
|
|
: super_t(x), begin_(x), n_(n), post_(post){};
|
|
|
|
|
friend class thrust::iterator_core_access;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
unsigned int post_;
|
|
|
|
|
unsigned int n_;
|
|
|
|
|
const T* begin_;
|
|
|
|
|
__host__ __device__ typename super_t::reference dereference() const {
|
|
|
|
|
return *(begin_ + (((this->base() - begin_) / post_) % n_));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
template <typename Functor, typename T, typename Place>
|
|
|
|
|
struct TransformFunctor {
|
|
|
|
|
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
|
|
|
|
|