You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
225 lines
4.4 KiB
225 lines
4.4 KiB
#include "paddle/framework/ddim.h"

#include <cstddef>
#include <stdexcept>
#include <vector>
|
|
|
|
namespace paddle {
|
|
namespace framework {
|
|
|
|
///@cond HIDDEN
|
|
|
|
template <int i>
|
|
Dim<i> make_dim(const int* d) {
|
|
return Dim<i>(*d, make_dim<i - 1>(d + 1));
|
|
}
|
|
|
|
template <>
|
|
Dim<1> make_dim<1>(const int* d) {
|
|
return Dim<1>(*d);
|
|
}
|
|
|
|
/// Overwrites `ddim` with a Dim<n> built from the n integers at `dims`.
///
/// @param ddim  destination; replaced by the newly built shape.
/// @param dims  pointer to at least n extents.
/// @param n     number of dimensions; must be in [1, 9].
/// @throws std::invalid_argument if n is outside [1, 9].
void make_ddim(DDim& ddim, const int* dims, int n) {
  // Reject unsupported ranks up front; the dispatch below only covers 1..9.
  if (n < 1 || n > 9) {
    throw std::invalid_argument(
        "Dynamic dimensions must have between [1, 9] dimensions.");
  }

  // The rank must be a compile-time template argument, hence one case per rank.
  switch (n) {
    case 1: ddim = make_dim<1>(dims); return;
    case 2: ddim = make_dim<2>(dims); return;
    case 3: ddim = make_dim<3>(dims); return;
    case 4: ddim = make_dim<4>(dims); return;
    case 5: ddim = make_dim<5>(dims); return;
    case 6: ddim = make_dim<6>(dims); return;
    case 7: ddim = make_dim<7>(dims); return;
    case 8: ddim = make_dim<8>(dims); return;
    case 9: ddim = make_dim<9>(dims); return;
  }
}
|
|
|
|
///@endcond
|
|
|
|
/// Builds a DDim whose rank and extents come from a brace list, e.g.
/// make_ddim({2, 3}).
/// @throws std::invalid_argument if the list length is outside [1, 9].
DDim make_ddim(std::initializer_list<int> dims) {
  // Placeholder value; the pointer-based overload replaces it entirely.
  DDim shape(make_dim(0));
  const int rank = static_cast<int>(dims.size());
  make_ddim(shape, dims.begin(), rank);
  return shape;
}
|
|
|
|
/// Builds a DDim whose rank and extents come from `dims`.
/// @throws std::invalid_argument if dims.size() is outside [1, 9]
///         (including an empty vector).
DDim make_ddim(const std::vector<int>& dims) {
  // Placeholder value; the pointer-based make_ddim overload overwrites it
  // with the Dim<N> matching dims.size(), or throws.
  DDim result(make_dim(0));
  // data() rather than &dims[0]: indexing an empty vector is undefined
  // behaviour, whereas data() is always valid to call. An empty vector now
  // reaches the rank check and throws instead of invoking UB.
  make_ddim(result, dims.data(), static_cast<int>(dims.size()));
  return result;
}
|
|
|
|
///@cond HIDDEN
|
|
// XXX For some reason, putting this in an anonymous namespace causes errors
|
|
/// Visitor that yields a writable reference to element `idx` of whichever
/// Dim<D> the DDim variant currently holds.
class DynamicMutableIndexer : public boost::static_visitor<int&> {
 public:
  // explicit: an int should never silently convert into a visitor.
  explicit DynamicMutableIndexer(int idx) : idx_(idx) {}

  /// Returns dim[idx_] by reference so callers can assign through it.
  template <int D>
  int& operator()(Dim<D>& dim) const {
    return dim[idx_];
  }

 private:
  int idx_;
};
|
|
|
|
/// Visitor that reads element `idx` (by value) of whichever Dim<D> the
/// DDim variant currently holds.
class DynamicConstIndexer : public boost::static_visitor<int> {
 public:
  // explicit: an int should never silently convert into a visitor.
  explicit DynamicConstIndexer(int idx) : idx_(idx) {}

  /// Returns a copy of dim[idx_].
  template <int D>
  int operator()(const Dim<D>& dim) const {
    return dim[idx_];
  }

 private:
  int idx_;
};
|
|
|
|
///@endcond
|
|
|
|
/// Writable access to extent `idx` of the held Dim<D>.
int& DDim::operator[](int idx) {
  const DynamicMutableIndexer indexer(idx);
  return boost::apply_visitor(indexer, var);
}
|
|
|
|
/// Read-only access to extent `idx` of the held Dim<D>.
int DDim::operator[](int idx) const {
  const DynamicConstIndexer reader(idx);
  return boost::apply_visitor(reader, var);
}
|
|
|
|
/// Two DDims are equal when they hold the same variant alternative and every
/// extent matches.
bool DDim::operator==(DDim d) const {
  // Different alternatives can never compare equal.
  if (var.which() != d.getVar().which()) {
    return false;
  }
  // Same alternative: compare the flattened extents element-wise.
  const std::vector<int> mine = vectorize(*this);
  const std::vector<int> theirs = vectorize(d);
  return mine == theirs;
}
|
|
|
|
/// Logical negation of operator==.
bool DDim::operator!=(DDim d) const {
  const bool same = (*this == d);
  return !same;
}
|
|
|
|
/// Element-wise sum of two shapes of equal rank.
/// @throws std::invalid_argument if the operands have different ranks.
DDim DDim::operator+(DDim d) const {
  const std::vector<int> lhs = vectorize(*this);
  const std::vector<int> rhs = vectorize(d);

  // The previous assert() vanishes under NDEBUG. After that, a rank mismatch
  // either read past the end of rhs or silently dropped rhs's extra
  // dimensions, so reject it explicitly.
  if (lhs.size() != rhs.size()) {
    throw std::invalid_argument(
        "DDim::operator+ requires operands with the same number of "
        "dimensions.");
  }

  std::vector<int> sum;
  sum.reserve(lhs.size());
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    sum.push_back(lhs[i] + rhs[i]);
  }
  return make_ddim(sum);
}
|
|
|
|
/// Element-wise product of two shapes of equal rank.
/// @throws std::invalid_argument if the operands have different ranks.
DDim DDim::operator*(DDim d) const {
  const std::vector<int> lhs = vectorize(*this);
  const std::vector<int> rhs = vectorize(d);

  // The previous assert() vanishes under NDEBUG. After that, a rank mismatch
  // either read past the end of rhs or silently dropped rhs's extra
  // dimensions, so reject it explicitly.
  if (lhs.size() != rhs.size()) {
    throw std::invalid_argument(
        "DDim::operator* requires operands with the same number of "
        "dimensions.");
  }

  std::vector<int> prod;
  prod.reserve(lhs.size());
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    prod.push_back(lhs[i] * rhs[i]);
  }
  return make_ddim(prod);
}
|
|
|
|
/// Free-function read of extent `idx`; forwards to DDim::operator[] const.
int get(const DDim& ddim, int idx) {
  const int extent = ddim[idx];
  return extent;
}
|
|
|
|
/// Free-function write of extent `idx`; assigns through DDim::operator[].
void set(DDim& ddim, int idx, int value) {
  int& extent = ddim[idx];
  extent = value;
}
|
|
|
|
///@cond HIDDEN
|
|
/// Visitor that appends every extent of the held Dim<D> to a caller-owned
/// vector, in order from head to innermost tail.
struct VectorizeVisitor : public boost::static_visitor<> {
  // Destination; referenced, not owned, so it must outlive the visitor.
  std::vector<int>& vector;

  // explicit: a std::vector<int>& should not silently become a visitor.
  explicit VectorizeVisitor(std::vector<int>& v) : vector(v) {}

  /// Appends t.head, then recurses into t.tail.
  template <typename T>
  void operator()(const T& t) {
    vector.push_back(t.head);
    this->operator()(t.tail);
  }

  /// Non-template overload. Overload resolution prefers it for Dim<1>, which
  /// ends the recursion.
  void operator()(const Dim<1>& t) { vector.push_back(t.head); }
};
|
|
///@endcond
|
|
|
|
std::vector<int> vectorize(const DDim& ddim) {
|
|
std::vector<int> result;
|
|
VectorizeVisitor visitor(result);
|
|
boost::apply_visitor(visitor, ddim);
|
|
return result;
|
|
}
|
|
|
|
ssize_t product(const DDim& ddim) {
|
|
ssize_t result = 1;
|
|
std::vector<int> v = vectorize(ddim);
|
|
for (auto i : v) {
|
|
result *= i;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
///\cond HIDDEN
|
|
|
|
/// Visitor that reports the compile-time rank D of the held Dim<D>.
struct ArityVisitor : boost::static_visitor<int> {
  // Only the type matters, so the argument is left unnamed.
  template <int D>
  int operator()(const Dim<D>&) const {
    return D;
  }
};
|
|
|
|
///\endcond
|
|
|
|
int arity(const DDim& d) { return boost::apply_visitor(ArityVisitor(), d); }
|
|
|
|
///\cond HIDDEN
|
|
|
|
/// Visitor that streams the held Dim<D> to an output stream.
struct DDimPrinter : boost::static_visitor<void> {
  // Target stream; referenced, not owned.
  std::ostream& os;
  // explicit: a std::ostream& should not silently become a visitor.
  explicit DDimPrinter(std::ostream& os_) : os(os_) {}

  /// Forwards to the operator<< overload for the held alternative's type.
  template <typename T>
  void operator()(const T& t) {
    os << t;
  }
};
|
|
|
|
///\endcond
|
|
|
|
std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
|
|
DDimPrinter printer(os);
|
|
boost::apply_visitor(printer, ddim);
|
|
return os;
|
|
}
|
|
|
|
} // namespace framework
|
|
} // namespace paddle
|