|
|
|
@ -60,12 +60,13 @@ inline void get_mid_dims(const framework::DDim& x_dims,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename Place>
|
|
|
|
|
struct RowwiseTransformIterator;
|
|
|
|
|
class RowwiseTransformIterator;
|
|
|
|
|
template <typename T, typename Place>
|
|
|
|
|
struct MidWiseTransformIterator;
|
|
|
|
|
class MidWiseTransformIterator;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
class RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
public:
|
|
|
|
|
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
|
|
|
|
|
|
|
|
|
|
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
@ -86,13 +87,15 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const T* ptr_;
|
|
|
|
|
int i_;
|
|
|
|
|
int64_t n_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
class MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
public:
|
|
|
|
|
MidWiseTransformIterator(const T* ptr, int n, int post)
|
|
|
|
|
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
|
|
|
|
|
|
|
|
|
@ -113,6 +116,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const T* ptr_;
|
|
|
|
|
int i_;
|
|
|
|
|
int64_t j_;
|
|
|
|
@ -122,7 +126,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct RowwiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
class RowwiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
: public thrust::iterator_adaptor<
|
|
|
|
|
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
|
|
|
|
|
public:
|
|
|
|
@ -142,7 +146,7 @@ struct RowwiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct MidWiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
class MidWiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
: public thrust::iterator_adaptor<
|
|
|
|
|
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
|
|
|
|
|
public:
|
|
|
|
@ -164,7 +168,8 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
template <typename Functor, typename T, typename Place>
|
|
|
|
|
struct TransformFunctor {
|
|
|
|
|
class TransformFunctor {
|
|
|
|
|
public:
|
|
|
|
|
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
|
|
|
|
|
framework::Tensor* z, const platform::DeviceContext& ctx,
|
|
|
|
|
Functor func)
|
|
|
|
@ -192,6 +197,7 @@ struct TransformFunctor {
|
|
|
|
|
z_, func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const T* x_;
|
|
|
|
|
const T* y_;
|
|
|
|
|
T* z_;
|
|
|
|
|