|
|
@ -401,5 +401,20 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
|
|
|
|
return result;
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <int D, int S>
|
|
|
|
|
|
|
|
Dim<D> slice(const Dim<S>& dim, int begin, int end) {
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(begin < end,
|
|
|
|
|
|
|
|
"Begin index must be less than end index in Dim slice.");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(begin >= 0 && end <= S && end - begin == D,
|
|
|
|
|
|
|
|
"Index error occurs in Dim slice.");
|
|
|
|
|
|
|
|
if (begin > 0) {
|
|
|
|
|
|
|
|
return slice<D>(dim.tail, begin - 1, end - 1);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (D == 1) {
|
|
|
|
|
|
|
|
return Dim<1>(dim.head);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return Dim<D>(dim.head, slice<D - 1>(dim.tail, 0, end - 1));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|