!6239 add matrix transpose for fp16

Merge pull request !6239 from lixian/master
pull/6239/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0719657ae6

@ -217,19 +217,158 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
}
}
void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int c = 0; c < channel; c++) {
for (int hw = 0; hw < plane; hw++) {
int nhwc_index = n * channel * plane + hw * channel + c;
int nchw_index = n * channel * plane + c * plane + hw;
((float16_t *)(dst))[nhwc_index] = ((const float16_t *)(src))[nchw_index];
void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel) {
int hw16 = plane / C16NUM * C16NUM;
int c8 = channel / C8NUM * C8NUM;
int batch = plane * channel;
for (int n = 0; n < batches; n++) {
const float16_t *src_batch = (const float16_t *)src + n * batch;
float16_t *dst_batch = (float16_t *)dst + n * batch;
int hw = 0;
for (; hw < hw16; hw += C16NUM) {
int c = 0;
for (; c < c8; c += C8NUM) {
const float16_t *src_ptr = src_batch + hw * channel + c;
float16_t *dst_ptr = dst_batch + c * plane + hw;
#ifdef ENABLE_ARM64
size_t srcStride = channel * sizeof(float16_t);
size_t dstStride = plane * sizeof(float16_t);
asm volatile(
"mov x10, %[src_ptr]\n"
"mov x11, %[dst_ptr]\n"
"ld1 {v0.8h}, [x10], %[srcStride]\n"
"ld1 {v1.8h}, [x10], %[srcStride]\n"
"ld1 {v2.8h}, [x10], %[srcStride]\n"
"ld1 {v3.8h}, [x10], %[srcStride]\n"
"ld1 {v4.8h}, [x10], %[srcStride]\n"
"ld1 {v5.8h}, [x10], %[srcStride]\n"
"ld1 {v6.8h}, [x10], %[srcStride]\n"
"ld1 {v7.8h}, [x10], %[srcStride]\n"
"zip1 v16.8h, v0.8h, v1.8h\n"
"zip1 v17.8h, v2.8h, v3.8h\n"
"zip1 v18.8h, v4.8h, v5.8h\n"
"zip1 v19.8h, v6.8h, v7.8h\n"
"ld1 {v8.8h}, [x10], %[srcStride]\n"
"ld1 {v9.8h}, [x10], %[srcStride]\n"
"ld1 {v10.8h}, [x10], %[srcStride]\n"
"ld1 {v11.8h}, [x10], %[srcStride]\n"
"ld1 {v12.8h}, [x10], %[srcStride]\n"
"ld1 {v13.8h}, [x10], %[srcStride]\n"
"ld1 {v14.8h}, [x10], %[srcStride]\n"
"ld1 {v15.8h}, [x10], %[srcStride]\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip1 v16.8h, v8.8h, v9.8h\n"
"zip1 v17.8h, v10.8h, v11.8h\n"
"zip1 v18.8h, v12.8h, v13.8h\n"
"zip1 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"add x10, x11, #16\n"
"st1 {v24.8h}, [x11], %[dstStride]\n"
"st1 {v28.8h}, [x10], %[dstStride]\n"
"st1 {v26.8h}, [x11], %[dstStride]\n"
"st1 {v30.8h}, [x10], %[dstStride]\n"
"st1 {v25.8h}, [x11], %[dstStride]\n"
"st1 {v29.8h}, [x10], %[dstStride]\n"
"st1 {v27.8h}, [x11], %[dstStride]\n"
"st1 {v31.8h}, [x10], %[dstStride]\n"
"zip2 v16.8h, v0.8h, v1.8h\n"
"zip2 v17.8h, v2.8h, v3.8h\n"
"zip2 v18.8h, v4.8h, v5.8h\n"
"zip2 v19.8h, v6.8h, v7.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip2 v16.8h, v8.8h, v9.8h\n"
"zip2 v17.8h, v10.8h, v11.8h\n"
"zip2 v18.8h, v12.8h, v13.8h\n"
"zip2 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"st1 {v24.8h}, [x11], %[dstStride]\n"
"st1 {v28.8h}, [x10], %[dstStride]\n"
"st1 {v26.8h}, [x11], %[dstStride]\n"
"st1 {v30.8h}, [x10], %[dstStride]\n"
"st1 {v25.8h}, [x11], %[dstStride]\n"
"st1 {v29.8h}, [x10], %[dstStride]\n"
"st1 {v27.8h}, [x11], %[dstStride]\n"
"st1 {v31.8h}, [x10], %[dstStride]\n"
:
: [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
#else
for (int tr = 0; tr < C16NUM; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {
dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
}
}
#endif
}
for (; c < channel; c++) {
const float16_t *src_ptr = src_batch + hw * channel + c;
float16_t *dst_ptr = dst_batch + c * plane + hw;
for (size_t i = 0; i < C16NUM; i++) {
dst_ptr[i] = src_ptr[i * channel];
}
}
}
for (; hw < plane; hw++) {
const float16_t *src_ptr = src_batch + hw * channel;
float16_t *dst_ptr = dst_batch + hw;
for (size_t i = 0; i < channel; i++) {
dst_ptr[i * plane] = src_ptr[i];
}
}
}
return;
}
void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
return PackNHWCToNCHWFp16(src, dst, batch, channel, plane);
}
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int ic4 = UP_DIV(channel, C4NUM);
int c4_channel = ic4 * C4NUM;

@ -43,6 +43,8 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel);

Loading…
Cancel
Save