|
|
|
@ -72,38 +72,36 @@ struct Dim {
|
|
|
|
|
|
|
|
|
|
// Base case specialization
|
|
|
|
|
template <>
|
|
|
|
|
struct Dim<1> {
|
|
|
|
|
static constexpr int dimensions = 1;
|
|
|
|
|
struct Dim<0> {
|
|
|
|
|
static constexpr int dimensions = 0;
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
|
Dim(int64_t _head) : head(_head) {}
|
|
|
|
|
Dim(int64_t _head) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
|
Dim() : head(0) {}
|
|
|
|
|
Dim() {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
|
Dim(int idx, const Dim<1>& size) : head(idx) {
|
|
|
|
|
Dim(int idx, const Dim<0>& size) {
|
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
|
if (idx >= size.head) {
|
|
|
|
|
if (idx > 0) {
|
|
|
|
|
throw std::invalid_argument("Index out of range.");
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_ASSERT(idx < size.head);
|
|
|
|
|
PADDLE_ASSERT(idx == 0);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
|
bool operator==(const Dim<1>& o) const { return (head == o.head); }
|
|
|
|
|
bool operator==(const Dim<0>& o) const { return true; }
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
|
bool operator!=(const Dim<1>& o) const { return !(*this == o); }
|
|
|
|
|
bool operator!=(const Dim<0>& o) const { return false; }
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE
|
|
|
|
|
int64_t& operator[](int idx);
|
|
|
|
|
HOSTDEVICE
|
|
|
|
|
int64_t operator[](int idx) const;
|
|
|
|
|
|
|
|
|
|
int64_t head;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) {
|
|
|
|
|
HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) {
|
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
|
if (idx != 0) {
|
|
|
|
|
throw std::invalid_argument("Invalid index");
|
|
|
|
|
}
|
|
|
|
|
throw std::invalid_argument("Invalid index");
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_ASSERT(idx == 0);
|
|
|
|
|
PADDLE_ASSERT(false);
|
|
|
|
|
#endif
|
|
|
|
|
return dim.head;
|
|
|
|
|
static int64_t head = 0;
|
|
|
|
|
return head;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <int D>
|
|
|
|
@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) {
|
|
|
|
|
HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) {
|
|
|
|
|
#ifndef __CUDA_ARCH__
|
|
|
|
|
if (idx != 0) {
|
|
|
|
|
throw std::invalid_argument("Invalid index");
|
|
|
|
|
}
|
|
|
|
|
throw std::invalid_argument("Invalid index");
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_ASSERT(idx == 0);
|
|
|
|
|
PADDLE_ASSERT(false);
|
|
|
|
|
#endif
|
|
|
|
|
return dim.head;
|
|
|
|
|
static int64_t head = 0;
|
|
|
|
|
return head;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Dynamic access to constant Dim
|
|
|
|
|
inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const {
|
|
|
|
|
inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const {
|
|
|
|
|
return indexer(*this, i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Dynamic access to mutable Dim
|
|
|
|
|
inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) {
|
|
|
|
|
inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) {
|
|
|
|
|
return indexer(*this, i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
|
|
|
|
|
// Base case dot product of two Dims
|
|
|
|
|
// Notice it is inline because it is no longer a template
|
|
|
|
|
template <>
|
|
|
|
|
HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) {
|
|
|
|
|
return a.head * b.head;
|
|
|
|
|
HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Product of a Dim
|
|
|
|
@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
|
|
|
|
|
// Base case product of a Dim
|
|
|
|
|
// Notice it is inline because it is no longer a template
|
|
|
|
|
template <>
|
|
|
|
|
HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) {
|
|
|
|
|
return prod * a.head;
|
|
|
|
|
HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) {
|
|
|
|
|
return prod;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Is 0 <= idx_i < size_i for all i?
|
|
|
|
@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
|
|
|
|
|
// Base case of is 0 <= idx_i < size_i ?
|
|
|
|
|
// Notice it is inline because it is no longer a template
|
|
|
|
|
template <>
|
|
|
|
|
HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
|
|
|
|
|
return ((0 <= idx.head) && (idx.head < size.head));
|
|
|
|
|
HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -294,8 +290,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) {
|
|
|
|
|
// Base case of ex_prefix_mul
|
|
|
|
|
// Notice it is inline because it is no longer a template
|
|
|
|
|
template <>
|
|
|
|
|
HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
|
|
|
|
|
return Dim<1>(mul);
|
|
|
|
|
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int mul) {
|
|
|
|
|
return Dim<0>();
|
|
|
|
|
}
|
|
|
|
|
///\endcond
|
|
|
|
|
|
|
|
|
@ -309,8 +305,8 @@ HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
|
|
|
|
|
|
|
|
|
|
// Base case
|
|
|
|
|
template <>
|
|
|
|
|
HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) {
|
|
|
|
|
return Dim<1>(a.head + b.head);
|
|
|
|
|
HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) {
|
|
|
|
|
return Dim<0>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <int i>
|
|
|
|
@ -328,8 +324,8 @@ HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
|
|
|
|
|
|
|
|
|
|
// Base case
|
|
|
|
|
template <>
|
|
|
|
|
HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) {
|
|
|
|
|
return Dim<1>(a.head * b.head);
|
|
|
|
|
HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) {
|
|
|
|
|
return Dim<0>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <int i>
|
|
|
|
@ -356,10 +352,9 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
|
|
|
|
|
///\cond HIDDEN
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size,
|
|
|
|
|
const Dim<1>& stride) {
|
|
|
|
|
int norm_stride = size.head == 1 ? 0 : stride.head;
|
|
|
|
|
return Dim<1>(norm_stride);
|
|
|
|
|
HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size,
|
|
|
|
|
const Dim<0>& stride) {
|
|
|
|
|
return Dim<0>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///\endcond
|
|
|
|
@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
|
|
|
|
|
return os;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
|
|
|
|
|
return os;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <int i>
|
|
|
|
|
HOST std::string Dim<i>::to_string() const {
|
|
|
|
|
std::stringstream stream;
|
|
|
|
|