|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
#include "paddle/framework/ddim.h"
|
|
|
|
|
#include "paddle/framework/enforce.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -190,6 +191,46 @@ ssize_t product(const DDim& ddim) {
|
|
|
|
|
return boost::apply_visitor(visitor, ddim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct SliceVectorizeVisitor : public boost::static_visitor<> {
|
|
|
|
|
std::vector<int>& vector;
|
|
|
|
|
int begin;
|
|
|
|
|
int end;
|
|
|
|
|
|
|
|
|
|
SliceVectorizeVisitor(std::vector<int>& v, int b, int e)
|
|
|
|
|
: vector(v), begin(b), end(e) {
|
|
|
|
|
PADDLE_ENFORCE(begin < end,
|
|
|
|
|
"Begin index must be less than end index in ddim slice.");
|
|
|
|
|
PADDLE_ENFORCE(begin >= 0,
|
|
|
|
|
"Begin index can't be less than zero in ddim slice.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <int S>
|
|
|
|
|
void operator()(const Dim<S>& dim) {
|
|
|
|
|
if (begin == 0) {
|
|
|
|
|
vector.push_back(dim.head);
|
|
|
|
|
} else {
|
|
|
|
|
--begin;
|
|
|
|
|
}
|
|
|
|
|
--end;
|
|
|
|
|
if (end > 0) {
|
|
|
|
|
this->operator()(dim.tail);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void operator()(const Dim<1>& dim) {
|
|
|
|
|
PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound.");
|
|
|
|
|
vector.push_back(dim.head);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DDim slice_ddim(const DDim& dim, int begin, int end) {
|
|
|
|
|
std::vector<int> vec;
|
|
|
|
|
vec.reserve(end - begin);
|
|
|
|
|
SliceVectorizeVisitor visitor(vec, begin, end);
|
|
|
|
|
boost::apply_visitor(visitor, dim);
|
|
|
|
|
return make_ddim(vec);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///\cond HIDDEN
|
|
|
|
|
|
|
|
|
|
struct ArityVisitor : boost::static_visitor<int> {
|
|
|
|
|