add slice_dim draft

cblas_new
fengjiayi 8 years ago
parent 8bcd1faffc
commit ee90c2d22b

@ -401,5 +401,20 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
return result;
}
template <int D, int S>
Dim<D> slice(const Dim<S>& dim, int begin, int end) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in Dim slice.");
PADDLE_ENFORCE(begin >= 0 && end <= S && end - begin == D,
"Index error occurs in Dim slice.");
if (begin > 0) {
return slice<D>(dim.tail, begin - 1, end - 1);
}
if (D == 1) {
return Dim<1>(dim.head);
}
return Dim<D>(dim.head, slice<D - 1>(dim.tail, 0, end - 1));
}
} // namespace framework
} // namespace paddle

Loading…
Cancel
Save