|
|
@ -17,13 +17,13 @@ struct Dim {
|
|
|
|
static constexpr int dimensions = i;
|
|
|
|
static constexpr int dimensions = i;
|
|
|
|
|
|
|
|
|
|
|
|
template <typename... Args>
|
|
|
|
template <typename... Args>
|
|
|
|
HOSTDEVICE Dim(int _head, Args... _tail) : head(_head), tail(_tail...) {
|
|
|
|
HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
|
|
|
|
static_assert(sizeof...(_tail) == i - 1,
|
|
|
|
static_assert(sizeof...(_tail) == i - 1,
|
|
|
|
"Dim initialized with the wrong number of parameters");
|
|
|
|
"Dim initialized with the wrong number of parameters");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
Dim(int _head, const Dim<i - 1>& _tail) : head(_head), tail(_tail) {}
|
|
|
|
Dim(int64_t _head, const Dim<i - 1>& _tail) : head(_head), tail(_tail) {}
|
|
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
Dim() : head(0), tail() {}
|
|
|
|
Dim() : head(0), tail() {}
|
|
|
@ -31,12 +31,12 @@ struct Dim {
|
|
|
|
/** Construct a Dim from a linear index and size. Uses Fortran order
|
|
|
|
/** Construct a Dim from a linear index and size. Uses Fortran order
|
|
|
|
* indexing. */
|
|
|
|
* indexing. */
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
Dim(int idx, const Dim<i>& size)
|
|
|
|
Dim(int64_t idx, const Dim<i>& size)
|
|
|
|
: head(idx % size.head), tail(idx / size.head, size.tail) {}
|
|
|
|
: head(idx % size.head), tail(idx / size.head, size.tail) {}
|
|
|
|
|
|
|
|
|
|
|
|
/** Construct a Dim with each dimension set to the given index */
|
|
|
|
/** Construct a Dim with each dimension set to the given index */
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
Dim(int idx) : head(idx), tail(idx) {}
|
|
|
|
Dim(int64_t idx) : head(idx), tail(idx) {}
|
|
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
bool operator==(const Dim<i>& o) const {
|
|
|
|
bool operator==(const Dim<i>& o) const {
|
|
|
@ -47,13 +47,13 @@ struct Dim {
|
|
|
|
bool operator!=(const Dim<i>& o) const { return !(*this == o); }
|
|
|
|
bool operator!=(const Dim<i>& o) const { return !(*this == o); }
|
|
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
int& operator[](int idx);
|
|
|
|
int64_t& operator[](int idx);
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
int operator[](int idx) const;
|
|
|
|
int64_t operator[](int idx) const;
|
|
|
|
|
|
|
|
|
|
|
|
HOST std::string to_string() const;
|
|
|
|
HOST std::string to_string() const;
|
|
|
|
|
|
|
|
|
|
|
|
int head;
|
|
|
|
int64_t head;
|
|
|
|
Dim<i - 1> tail;
|
|
|
|
Dim<i - 1> tail;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -63,7 +63,7 @@ struct Dim<1> {
|
|
|
|
static constexpr int dimensions = 1;
|
|
|
|
static constexpr int dimensions = 1;
|
|
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
Dim(int _head) : head(_head) {}
|
|
|
|
Dim(int64_t _head) : head(_head) {}
|
|
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
Dim() : head(0) {}
|
|
|
|
Dim() : head(0) {}
|
|
|
@ -86,11 +86,11 @@ struct Dim<1> {
|
|
|
|
bool operator!=(const Dim<1>& o) const { return !(*this == o); }
|
|
|
|
bool operator!=(const Dim<1>& o) const { return !(*this == o); }
|
|
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
int& operator[](int idx);
|
|
|
|
int64_t& operator[](int idx);
|
|
|
|
HOSTDEVICE
|
|
|
|
HOSTDEVICE
|
|
|
|
int operator[](int idx) const;
|
|
|
|
int64_t operator[](int idx) const;
|
|
|
|
|
|
|
|
|
|
|
|
int head;
|
|
|
|
int64_t head;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
namespace {
|
|
|
@ -100,12 +100,12 @@ template <int i>
|
|
|
|
struct DimGetter {
|
|
|
|
struct DimGetter {
|
|
|
|
// Return a copy if Dim is const
|
|
|
|
// Return a copy if Dim is const
|
|
|
|
template <typename D>
|
|
|
|
template <typename D>
|
|
|
|
HOSTDEVICE static int impl(const D& d) {
|
|
|
|
HOSTDEVICE static int64_t impl(const D& d) {
|
|
|
|
return DimGetter<i - 1>::impl(d.tail);
|
|
|
|
return DimGetter<i - 1>::impl(d.tail);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Return a reference if Dim is mutable
|
|
|
|
// Return a reference if Dim is mutable
|
|
|
|
template <typename D>
|
|
|
|
template <typename D>
|
|
|
|
HOSTDEVICE static int& impl(D& d) {
|
|
|
|
HOSTDEVICE static int64_t& impl(D& d) {
|
|
|
|
return DimGetter<i - 1>::impl(d.tail);
|
|
|
|
return DimGetter<i - 1>::impl(d.tail);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
@ -115,18 +115,18 @@ template <>
|
|
|
|
struct DimGetter<0> {
|
|
|
|
struct DimGetter<0> {
|
|
|
|
// Return a copy if Dim is const
|
|
|
|
// Return a copy if Dim is const
|
|
|
|
template <typename D>
|
|
|
|
template <typename D>
|
|
|
|
HOSTDEVICE static int impl(const D& d) {
|
|
|
|
HOSTDEVICE static int64_t impl(const D& d) {
|
|
|
|
return d.head;
|
|
|
|
return d.head;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Return a reference if Dim is mutable
|
|
|
|
// Return a reference if Dim is mutable
|
|
|
|
template <typename D>
|
|
|
|
template <typename D>
|
|
|
|
HOSTDEVICE static int& impl(D& d) {
|
|
|
|
HOSTDEVICE static int64_t& impl(D& d) {
|
|
|
|
return d.head;
|
|
|
|
return d.head;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <int D>
|
|
|
|
template <int D>
|
|
|
|
HOSTDEVICE int& indexer(Dim<D>& dim, int idx) {
|
|
|
|
HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
if (idx < 0) {
|
|
|
|
if (idx < 0) {
|
|
|
|
throw std::invalid_argument("Tried to access a negative dimension");
|
|
|
|
throw std::invalid_argument("Tried to access a negative dimension");
|
|
|
@ -141,7 +141,7 @@ HOSTDEVICE int& indexer(Dim<D>& dim, int idx) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
template <>
|
|
|
|
HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) {
|
|
|
|
HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) {
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
if (idx != 0) {
|
|
|
|
if (idx != 0) {
|
|
|
|
throw std::invalid_argument("Invalid index");
|
|
|
|
throw std::invalid_argument("Invalid index");
|
|
|
@ -153,7 +153,7 @@ HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <int D>
|
|
|
|
template <int D>
|
|
|
|
HOSTDEVICE int indexer(const Dim<D>& dim, int idx) {
|
|
|
|
HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
if (idx < 0) {
|
|
|
|
if (idx < 0) {
|
|
|
|
throw std::invalid_argument("Tried to access a negative dimension");
|
|
|
|
throw std::invalid_argument("Tried to access a negative dimension");
|
|
|
@ -168,7 +168,7 @@ HOSTDEVICE int indexer(const Dim<D>& dim, int idx) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
template <>
|
|
|
|
HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) {
|
|
|
|
HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) {
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
if (idx != 0) {
|
|
|
|
if (idx != 0) {
|
|
|
|
throw std::invalid_argument("Invalid index");
|
|
|
|
throw std::invalid_argument("Invalid index");
|
|
|
@ -182,73 +182,76 @@ HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) {
|
|
|
|
} // namespace
|
|
|
|
} // namespace
|
|
|
|
// Static access to constant Dim
|
|
|
|
// Static access to constant Dim
|
|
|
|
template <int i, int l>
|
|
|
|
template <int i, int l>
|
|
|
|
HOSTDEVICE int get(const Dim<l>& d) {
|
|
|
|
HOSTDEVICE int64_t get(const Dim<l>& d) {
|
|
|
|
return DimGetter<i>::impl(d);
|
|
|
|
return DimGetter<i>::impl(d);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Static access to mutable Dim
|
|
|
|
// Static access to mutable Dim
|
|
|
|
template <int i, int l>
|
|
|
|
template <int i, int l>
|
|
|
|
HOSTDEVICE int& get(Dim<l>& d) {
|
|
|
|
HOSTDEVICE int64_t& get(Dim<l>& d) {
|
|
|
|
return DimGetter<i>::impl(d);
|
|
|
|
return DimGetter<i>::impl(d);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Dynamic access to constant Dim
|
|
|
|
// Dynamic access to constant Dim
|
|
|
|
template <int l>
|
|
|
|
template <int l>
|
|
|
|
HOSTDEVICE int Dim<l>::operator[](int i) const {
|
|
|
|
HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
|
|
|
|
return indexer(*this, i);
|
|
|
|
return indexer(*this, i);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Dynamic access to mutable Dim
|
|
|
|
// Dynamic access to mutable Dim
|
|
|
|
template <int l>
|
|
|
|
template <int l>
|
|
|
|
HOSTDEVICE int& Dim<l>::operator[](int i) {
|
|
|
|
HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
|
|
|
|
return indexer(*this, i);
|
|
|
|
return indexer(*this, i);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Dynamic access to constant Dim
|
|
|
|
// Dynamic access to constant Dim
|
|
|
|
inline HOSTDEVICE int Dim<1>::operator[](int i) const {
|
|
|
|
inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const {
|
|
|
|
return indexer(*this, i);
|
|
|
|
return indexer(*this, i);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Dynamic access to mutable Dim
|
|
|
|
// Dynamic access to mutable Dim
|
|
|
|
inline HOSTDEVICE int& Dim<1>::operator[](int i) { return indexer(*this, i); }
|
|
|
|
inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) {
|
|
|
|
|
|
|
|
return indexer(*this, i);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Dynamic access to constant Dim
|
|
|
|
// Dynamic access to constant Dim
|
|
|
|
// without std::enable_if will try to instantiate this on get<0>(d)
|
|
|
|
// without std::enable_if will try to instantiate this on get<0>(d)
|
|
|
|
template <int l>
|
|
|
|
template <int l>
|
|
|
|
HOSTDEVICE typename std::enable_if<(l > 0), int>::type get(const Dim<l>& d,
|
|
|
|
HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l>& d,
|
|
|
|
int i) {
|
|
|
|
int i) {
|
|
|
|
return d[i];
|
|
|
|
return d[i];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Dynamic access to mutable Dim
|
|
|
|
// Dynamic access to mutable Dim
|
|
|
|
template <int l>
|
|
|
|
template <int l>
|
|
|
|
HOSTDEVICE typename std::enable_if<(l > 0), int&>::type get(Dim<l>& d, int i) {
|
|
|
|
HOSTDEVICE typename std::enable_if<(l > 0), int64_t&>::type get(Dim<l>& d,
|
|
|
|
|
|
|
|
int i) {
|
|
|
|
return d[i];
|
|
|
|
return d[i];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Dot product of two dims
|
|
|
|
// Dot product of two dims
|
|
|
|
template <int i>
|
|
|
|
template <int i>
|
|
|
|
HOSTDEVICE int linearize(const Dim<i>& a, const Dim<i>& b) {
|
|
|
|
HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
|
|
|
|
return a.head * b.head + linearize(a.tail, b.tail);
|
|
|
|
return a.head * b.head + linearize(a.tail, b.tail);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Base case dot product of two Dims
|
|
|
|
// Base case dot product of two Dims
|
|
|
|
// Notice it is inline because it is no longer a template
|
|
|
|
// Notice it is inline because it is no longer a template
|
|
|
|
template <>
|
|
|
|
template <>
|
|
|
|
HOSTDEVICE inline int linearize(const Dim<1>& a, const Dim<1>& b) {
|
|
|
|
HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) {
|
|
|
|
return a.head * b.head;
|
|
|
|
return a.head * b.head;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Product of a Dim
|
|
|
|
// Product of a Dim
|
|
|
|
template <int i>
|
|
|
|
template <int i>
|
|
|
|
HOSTDEVICE int product(const Dim<i>& a, int prod = 1) {
|
|
|
|
HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
|
|
|
|
return prod * a.head * product(a.tail);
|
|
|
|
return prod * a.head * product(a.tail);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Base case product of a Dim
|
|
|
|
// Base case product of a Dim
|
|
|
|
// Notice it is inline because it is no longer a template
|
|
|
|
// Notice it is inline because it is no longer a template
|
|
|
|
template <>
|
|
|
|
template <>
|
|
|
|
HOSTDEVICE inline int product(const Dim<1>& a, int prod) {
|
|
|
|
HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) {
|
|
|
|
return prod * a.head;
|
|
|
|
return prod * a.head;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|