|
|
|
@ -38,13 +38,14 @@ class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
auto width = dst_dims[1];
|
|
|
|
|
auto* src_data = src.data<T>();
|
|
|
|
|
auto* dst_data = dst->data<T>();
|
|
|
|
|
for (int i = 0; i < height; ++i) {
|
|
|
|
|
if (is_src_index) {
|
|
|
|
|
memcpy(dst_data + i * width, src_data + index[i] * width,
|
|
|
|
|
width * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
memcpy(dst_data + index[i] * width, src_data + i * width,
|
|
|
|
|
width * sizeof(T));
|
|
|
|
|
const int sz = width * sizeof(T);
|
|
|
|
|
if (is_src_index) {
|
|
|
|
|
for (int i = 0; i < height; ++i) {
|
|
|
|
|
memcpy(dst_data + i * width, src_data + index[i] * width, sz);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < height; ++i) {
|
|
|
|
|
memcpy(dst_data + index[i] * width, src_data + i * width, sz);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|