|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
#include "paddle/framework/eigen.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "paddle/platform/transform.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
@ -54,6 +55,113 @@ inline void get_mid_dims(const framework::DDim& x_dims,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename Place>
|
|
|
|
|
struct RowwiseTransformIterator;
|
|
|
|
|
template <typename T, typename Place>
|
|
|
|
|
struct MidWiseTransformIterator;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct RowwiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
|
|
|
|
|
|
|
|
|
|
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
|
++i_;
|
|
|
|
|
if (i_ == n_) {
|
|
|
|
|
i_ = 0;
|
|
|
|
|
}
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator==(
|
|
|
|
|
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return &(this->operator*()) == &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator!=(
|
|
|
|
|
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return &(this->operator*()) &= &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
|
|
|
|
|
|
const T* ptr_;
|
|
|
|
|
int i_;
|
|
|
|
|
int n_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct MidWiseTransformIterator<T, platform::CPUPlace> {
|
|
|
|
|
MidWiseTransformIterator(const T* ptr, int n, int post)
|
|
|
|
|
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
|
|
|
|
|
|
|
|
|
|
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
|
|
|
|
|
++j_;
|
|
|
|
|
if (j_ == post_) {
|
|
|
|
|
j_ = 0;
|
|
|
|
|
++i_;
|
|
|
|
|
if (i_ == n_) {
|
|
|
|
|
i_ = 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator==(
|
|
|
|
|
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return &(this->operator*()) == &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator!=(
|
|
|
|
|
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
|
|
|
|
|
return &(this->operator*()) &= &(*rhs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T& operator*() { return ptr_[i_]; }
|
|
|
|
|
|
|
|
|
|
const T* ptr_;
|
|
|
|
|
int i_;
|
|
|
|
|
int j_;
|
|
|
|
|
int n_;
|
|
|
|
|
int post_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Functor, typename T, typename Place>
|
|
|
|
|
struct TransformFunctor {
|
|
|
|
|
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
|
|
|
|
|
framework::Tensor* z, const framework::ExecutionContext& ctx,
|
|
|
|
|
Functor func)
|
|
|
|
|
: x_(x->data<T>()),
|
|
|
|
|
y_(y->data<T>()),
|
|
|
|
|
z_(z->mutable_data<T>(ctx.GetPlace())),
|
|
|
|
|
nx_(x->numel()),
|
|
|
|
|
ctx_(ctx),
|
|
|
|
|
func_(func) {}
|
|
|
|
|
|
|
|
|
|
inline void Run() const {
|
|
|
|
|
platform::Transform<Place> trans;
|
|
|
|
|
trans(ctx_.device_context(), x_, x_ + nx_, y_, z_, func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void RunRowWise(int n, int pre) const {
|
|
|
|
|
platform::Transform<Place> trans;
|
|
|
|
|
trans(ctx_.device_context(), x_, x_ + nx_,
|
|
|
|
|
RowwiseTransformIterator<T, Place>(y_, n), z_, func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void RunMidWise(int n, int pre, int post) const {
|
|
|
|
|
platform::Transform<Place> trans;
|
|
|
|
|
trans(ctx_.device_context(), x_, x_ + nx_,
|
|
|
|
|
MidWiseTransformIterator<T, Place>(y_, n, post), z_, func_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T* x_;
|
|
|
|
|
const T* y_;
|
|
|
|
|
T* z_;
|
|
|
|
|
int64_t nx_;
|
|
|
|
|
const framework::ExecutionContext& ctx_;
|
|
|
|
|
Functor func_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define EIGEN_FUNCTOR(name, eigen_op) \
|
|
|
|
|
struct Eigen##name##Functor { \
|
|
|
|
|
template <typename Place, typename T> \
|
|
|
|
|