|
|
@ -18,32 +18,10 @@ limitations under the License. */
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace framework {
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
struct DDimAssignFunctor {
|
|
|
|
|
|
|
|
static_assert(std::is_integral<T>::value, "T must be integral type");
|
|
|
|
|
|
|
|
using result_type = void;
|
|
|
|
|
|
|
|
explicit DDimAssignFunctor(const T* in) : in_(in) {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <int D>
|
|
|
|
|
|
|
|
inline void operator()(Dim<D>& dim) { // NOLINT
|
|
|
|
|
|
|
|
UnrollAssign<D>::Run(in_, dim.data());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const T* in_;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DDim::DDim(const int* d, int n) : rank_(n) {
|
|
|
|
|
|
|
|
this->apply_visitor(DDimAssignFunctor<int>(d));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DDim::DDim(const int64_t* d, int n) : rank_(n) {
|
|
|
|
|
|
|
|
this->apply_visitor(DDimAssignFunctor<int64_t>(d));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <int N>
|
|
|
|
template <int N>
|
|
|
|
Dim<N> make_dim(const int64_t* d) {
|
|
|
|
Dim<N> make_dim(const int64_t* d) {
|
|
|
|
Dim<N> ret;
|
|
|
|
Dim<N> ret;
|
|
|
|
for (int i = 0; i < N; ++i) ret[i] = d[i];
|
|
|
|
fix_dim_assign(d, ret.GetMutable());
|
|
|
|
return ret;
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -64,14 +42,14 @@ struct DDimEqualityVisitor {
|
|
|
|
|
|
|
|
|
|
|
|
template <int D>
|
|
|
|
template <int D>
|
|
|
|
inline bool operator()(const Dim<D>& self) const {
|
|
|
|
inline bool operator()(const Dim<D>& self) const {
|
|
|
|
return UnrollCompare<D>::Run(self.data(), d_);
|
|
|
|
return UnrollCompare<D>::Run(self.Get(), d_);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const int64_t* d_;
|
|
|
|
const int64_t* d_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
bool DDim::operator==(const DDim& d) const {
|
|
|
|
bool DDim::operator==(const DDim& d) const {
|
|
|
|
return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.data()));
|
|
|
|
return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.Get()));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
|
|
|
|
bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
|
|
|
@ -82,7 +60,7 @@ struct DDimPlusVisitor {
|
|
|
|
|
|
|
|
|
|
|
|
template <int D>
|
|
|
|
template <int D>
|
|
|
|
inline void operator()(Dim<D>& self) const {
|
|
|
|
inline void operator()(Dim<D>& self) const {
|
|
|
|
UnrollAdd<D>::Run(d1_, d2_, self.data());
|
|
|
|
UnrollAdd<D>::Run(d1_, d2_, self.GetMutable());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const int64_t* d1_;
|
|
|
|
const int64_t* d1_;
|
|
|
@ -93,7 +71,7 @@ DDim DDim::operator+(const DDim& d) const {
|
|
|
|
PADDLE_ENFORCE(rank_ == d.rank_);
|
|
|
|
PADDLE_ENFORCE(rank_ == d.rank_);
|
|
|
|
DDim ret;
|
|
|
|
DDim ret;
|
|
|
|
ret.rank_ = rank_;
|
|
|
|
ret.rank_ = rank_;
|
|
|
|
ret.apply_visitor(DDimPlusVisitor(data(), d.data()));
|
|
|
|
ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
|
|
|
|
return ret;
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -103,7 +81,7 @@ struct DDimMulVisitor {
|
|
|
|
|
|
|
|
|
|
|
|
template <int D>
|
|
|
|
template <int D>
|
|
|
|
inline void operator()(Dim<D>& self) const {
|
|
|
|
inline void operator()(Dim<D>& self) const {
|
|
|
|
UnrollMul<D>::Run(d1_, d2_, self.data());
|
|
|
|
UnrollMul<D>::Run(d1_, d2_, self.GetMutable());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const int64_t* d1_;
|
|
|
|
const int64_t* d1_;
|
|
|
@ -114,7 +92,7 @@ DDim DDim::operator*(const DDim& d) const {
|
|
|
|
PADDLE_ENFORCE(rank_ == d.rank_);
|
|
|
|
PADDLE_ENFORCE(rank_ == d.rank_);
|
|
|
|
DDim ret;
|
|
|
|
DDim ret;
|
|
|
|
ret.rank_ = rank_;
|
|
|
|
ret.rank_ = rank_;
|
|
|
|
ret.apply_visitor(DDimMulVisitor(data(), d.data()));
|
|
|
|
ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
|
|
|
|
return ret;
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -124,9 +102,7 @@ void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } // NOLINT
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> vectorize(const DDim& ddim) {
|
|
|
|
std::vector<int64_t> vectorize(const DDim& ddim) {
|
|
|
|
std::vector<int64_t> result(DDim::kMaxRank);
|
|
|
|
std::vector<int64_t> result(DDim::kMaxRank);
|
|
|
|
for (int i = 0; i < ddim.size(); ++i) {
|
|
|
|
dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
|
|
|
|
result[i] = ddim[i];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
result.resize(ddim.size());
|
|
|
|
result.resize(ddim.size());
|
|
|
|
return result;
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -135,9 +111,7 @@ std::vector<int64_t> vectorize(const DDim& ddim) {
|
|
|
|
// which does not fit cudnn inputs.
|
|
|
|
// which does not fit cudnn inputs.
|
|
|
|
std::vector<int> vectorize2int(const DDim& ddim) {
|
|
|
|
std::vector<int> vectorize2int(const DDim& ddim) {
|
|
|
|
std::vector<int> result(DDim::kMaxRank);
|
|
|
|
std::vector<int> result(DDim::kMaxRank);
|
|
|
|
for (int i = 0; i < ddim.size(); ++i) {
|
|
|
|
dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
|
|
|
|
result[i] = ddim[i];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
result.resize(ddim.size());
|
|
|
|
result.resize(ddim.size());
|
|
|
|
return result;
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -154,15 +128,11 @@ int64_t product(const DDim& ddim) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
DDim slice_ddim(const DDim& dim, int begin, int end) {
|
|
|
|
DDim slice_ddim(const DDim& dim, int begin, int end) {
|
|
|
|
PADDLE_ENFORCE(begin < end,
|
|
|
|
|
|
|
|
"Begin index must be less than end index in ddim slice.");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(begin >= 0,
|
|
|
|
PADDLE_ENFORCE(begin >= 0,
|
|
|
|
"Begin index can't be less than zero in ddim slice.");
|
|
|
|
"Begin index can't be less than zero in ddim slice.");
|
|
|
|
DDim ret;
|
|
|
|
DDim ret;
|
|
|
|
ret.rank_ = end - begin;
|
|
|
|
ret.rank_ = end - begin;
|
|
|
|
for (int i = 0; i < ret.rank_; ++i) {
|
|
|
|
dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
|
|
|
|
ret[i] = dim[i + begin];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|