You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/framework/ddim.cc

225 lines
4.4 KiB

#include "paddle/framework/ddim.h"
namespace paddle {
namespace framework {
///@cond HIDDEN
template <int i>
Dim<i> make_dim(const int* d) {
return Dim<i>(*d, make_dim<i - 1>(d + 1));
}
template <>
Dim<1> make_dim<1>(const int* d) {
return Dim<1>(*d);
}
void make_ddim(DDim& ddim, const int* dims, int n) {
switch (n) {
case 1:
ddim = make_dim<1>(dims);
break;
case 2:
ddim = make_dim<2>(dims);
break;
case 3:
ddim = make_dim<3>(dims);
break;
case 4:
ddim = make_dim<4>(dims);
break;
case 5:
ddim = make_dim<5>(dims);
break;
case 6:
ddim = make_dim<6>(dims);
break;
case 7:
ddim = make_dim<7>(dims);
break;
case 8:
ddim = make_dim<8>(dims);
break;
case 9:
ddim = make_dim<9>(dims);
break;
default:
throw std::invalid_argument(
"Dynamic dimensions must have between [1, 9] dimensions.");
}
}
///@endcond
DDim make_ddim(std::initializer_list<int> dims) {
DDim result(make_dim(0));
make_ddim(result, dims.begin(), dims.size());
return result;
}
DDim make_ddim(const std::vector<int>& dims) {
DDim result(make_dim(0));
make_ddim(result, &dims[0], dims.size());
return result;
}
///@cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes errors
class DynamicMutableIndexer : public boost::static_visitor<int&> {
public:
DynamicMutableIndexer(int idx) : idx_(idx) {}
template <int D>
int& operator()(Dim<D>& dim) const {
return dim[idx_];
}
private:
int idx_;
};
class DynamicConstIndexer : public boost::static_visitor<int> {
public:
DynamicConstIndexer(int idx) : idx_(idx) {}
template <int D>
int operator()(const Dim<D>& dim) const {
return dim[idx_];
}
private:
int idx_;
};
///@endcond
int& DDim::operator[](int idx) {
return boost::apply_visitor(DynamicMutableIndexer(idx), var);
}
int DDim::operator[](int idx) const {
return boost::apply_visitor(DynamicConstIndexer(idx), var);
}
bool DDim::operator==(DDim d) const {
if (var.which() != d.getVar().which()) {
return false;
} else {
std::vector<int> v1 = vectorize(*this);
std::vector<int> v2 = vectorize(d);
for (unsigned int i = 0; i < v1.size(); i++) {
if (v1[i] != v2[i]) {
return false;
}
}
return true;
}
}
bool DDim::operator!=(DDim d) const { return !(*this == d); }
DDim DDim::operator+(DDim d) const {
std::vector<int> v1 = vectorize(*this);
std::vector<int> v2 = vectorize(d);
std::vector<int> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] + v2[i]);
}
return make_ddim(v3);
}
DDim DDim::operator*(DDim d) const {
std::vector<int> v1 = vectorize(*this);
std::vector<int> v2 = vectorize(d);
std::vector<int> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] * v2[i]);
}
return make_ddim(v3);
}
int get(const DDim& ddim, int idx) { return ddim[idx]; }
void set(DDim& ddim, int idx, int value) { ddim[idx] = value; }
///@cond HIDDEN
struct VectorizeVisitor : public boost::static_visitor<> {
std::vector<int>& vector;
VectorizeVisitor(std::vector<int>& v) : vector(v) {}
template <typename T>
void operator()(const T& t) {
vector.push_back(t.head);
this->operator()(t.tail);
}
void operator()(const Dim<1>& t) { vector.push_back(t.head); }
};
///@endcond
std::vector<int> vectorize(const DDim& ddim) {
std::vector<int> result;
VectorizeVisitor visitor(result);
boost::apply_visitor(visitor, ddim);
return result;
}
ssize_t product(const DDim& ddim) {
ssize_t result = 1;
std::vector<int> v = vectorize(ddim);
for (auto i : v) {
result *= i;
}
return result;
}
///\cond HIDDEN
struct ArityVisitor : boost::static_visitor<int> {
template <int D>
int operator()(Dim<D>) const {
return D;
}
};
///\endcond
int arity(const DDim& d) { return boost::apply_visitor(ArityVisitor(), d); }
///\cond HIDDEN
struct DDimPrinter : boost::static_visitor<void> {
std::ostream& os;
DDimPrinter(std::ostream& os_) : os(os_) {}
template <typename T>
void operator()(const T& t) {
os << t;
}
};
///\endcond
std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDimPrinter printer(os);
boost::apply_visitor(printer, ddim);
return os;
}
} // namespace framework
} // namespace paddle