!7122 [MS][LITE][CPU] fp16 winograd init optimize

Merge pull request !7122 from liuzhongkai/winograd
pull/7122/MERGE
mindspore-ci-bot authored 4 years ago, committed by Gitee
commit 66d415fcda

@ -0,0 +1,206 @@
#ifdef __aarch64__
.text
.align 5
.global MatrixMultiplyWinogradFp16
#ifndef __APPLE__
.type MatrixMultiplyWinogradFp16, %function
#endif
// MatrixMultiplyWinogradFp16(float16_t *matrix_a, float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n, int in_channel)
// x0: matrix_a, x1: matrix_b, x2: matrix_c, x3: m, x4: k, x5: n, x6: in_channel
MatrixMultiplyWinogradFp16:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should also be preserved,
// whereas our coding style does not permit such an amount of parameters
sub sp, sp, #48
st1 {v8.8h}, [sp], #16
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
mov x8, #2
mul x10, x5, x8 // n * 2
mov x17, x3 // m
mul x13, x6, x8 // in_channel * 2
mul x21, x13, x4 // in_channel * k * 2
LoopM:
mov x15, x5 // n
mov x14, x1 // mat_b
LoopN:
mov x16, x0 // mat_a_m
sub x18, x5, x15 // ni
sub x19, x17, x3 // mi
mul x18, x18, x17 // ni * m
mov x11, x6 // in_channel
add x18, x18, x19 // (ni * m) + mi
mul x18, x18, x13 // x18 * channel_in * 2
add x20, x2, x18 // dst + offset
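// dispatch on remaining in_channel (x11): process 32 / 16 / 8 / 4 / 1 fp16 element(s) per block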
cmp x11, #32
bge LoopC32
cmp x11, #16
bge LoopC16
cmp x11, #8
bge LoopC8
cmp x11, #4
bge LoopC4
cmp x11, #1
bge LoopC
b EndLoopC
LoopC32:
mov x12, x14
mov x9, x4 // new_k
dup v5.8h, wzr
dup v6.8h, wzr
dup v7.8h, wzr
dup v8.8h, wzr
LoopK32:
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x16], x13
ldr h4, [x12]
add x12, x12, x10
fmla v5.8h, v0.8h, v4.h[0]
fmla v6.8h, v1.8h, v4.h[0]
fmla v7.8h, v2.8h, v4.h[0]
fmla v8.8h, v3.8h, v4.h[0]
subs x9, x9, #1
bne LoopK32
Write32:
st1 {v5.8h}, [x20], #16
st1 {v6.8h}, [x20], #16
st1 {v7.8h}, [x20], #16
st1 {v8.8h}, [x20], #16
sub x16, x16, x21 // back x13 * k
add x16, x16, #64 // add 64B
subs x11, x11, #32
beq EndLoopC
cmp x11, #32
bge LoopC32
cmp x11, #16
bge LoopC16
cmp x11, #8
bge LoopC8
cmp x11, #4
bge LoopC4
cmp x11, #1
bge LoopC
LoopC16:
dup v5.8h, wzr
dup v6.8h, wzr
mov x9, x4 // new_k
mov x12, x14
LoopK16:
ld1 {v0.8h, v1.8h}, [x16], x13
ldr h4, [x12]
add x12, x12, x10
fmla v5.8h, v0.8h, v4.h[0]
fmla v6.8h, v1.8h, v4.h[0]
subs x9, x9, #1
bne LoopK16
Write16:
st1 {v5.8h}, [x20], #16
st1 {v6.8h}, [x20], #16
sub x16, x16, x21 // back x13 * k
add x16, x16, #32 // add 32B
subs x11, x11, #16
beq EndLoopC
cmp x11, #16
bge LoopC16
cmp x11, #8
bge LoopC8
cmp x11, #4
bge LoopC4
cmp x11, #1
bge LoopC
LoopC8:
dup v5.8h, wzr
mov x9, x4 // new_k
mov x12, x14
LoopK8:
ld1 {v0.8h}, [x16], x13
ldr h4, [x12]
add x12, x12, x10
fmla v5.8h, v0.8h, v4.h[0]
subs x9, x9, #1
bne LoopK8
Write8:
st1 {v5.8h}, [x20], #16
sub x16, x16, x21 // ptr back x13 * k
add x16, x16, #16 // add 16B
subs x11, x11, #8
beq EndLoopC
cmp x11, #8
bge LoopC8
cmp x11, #4
bge LoopC4
cmp x11, #1
bge LoopC
LoopC4:
dup v5.4h, wzr
mov x9, x4 // new_k
mov x12, x14
LoopK4:
ld1 {v0.4h}, [x16], x13
ldr h4, [x12]
add x12, x12, x10
fmla v5.4h, v0.4h, v4.h[0]
subs x9, x9, #1
bne LoopK4
Write4:
st1 {v5.4h}, [x20], #8
sub x16, x16, x21 // ptr back x13 * k
add x16, x16, #8 // add 8B
subs x11, x11, #4
beq EndLoopC
cmp x11, #4
bge LoopC4
cmp x11, #1
bge LoopC
LoopC:
dup v5.8h, wzr
mov x9, x4 // new_k
mov x12, x14
LoopK:
ldr h0, [x16]
add x16, x16, x13
ldr h4, [x12]
add x12, x12, x10
fmul h0, h0, h4
fadd h5, h5, h0
subs x9, x9, #1
bne LoopK
Write:
str h5, [x20], #2
sub x16, x16, x21 // ptr back x13 * k
add x16, x16, #2 // ptr add 2B
subs x11, x11, #1
beq EndLoopC
b LoopC
EndLoopC:
add x14, x14, #2
subs x15, x15, #1
beq EndLoopN
b LoopN
EndLoopN:
subs x3, x3, #1
beq EndLoopM
add x0, x0, x21
b LoopM
EndLoopM:
sub sp, sp, #48
ld1 {v8.8h}, [sp], #16  // restore callee-saved v8 (load, not store)
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ret
#endif
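
For readers who do not want to trace the address arithmetic above, here is a minimal scalar sketch of what the routine computes, assuming the layouts implied by the offsets (matrix_a as (m, k, in_channel), matrix_b as (k, n)); the name MatrixMultiplyWinogradFp16Ref and the code are illustrative only and not part of the patch. Note that the destination offset (ni * m + mi) * in_channel means the assembly writes its result in (n, m, in_channel) order, whereas the portable C fallback in the next hunk writes (m, n, in_channel).

// Illustrative reference only (not in the patch): scalar model of the ARM64 routine above.
void MatrixMultiplyWinogradFp16Ref(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m,
                                   int k, int n, int in_channel) {
  for (int ni = 0; ni < n; ++ni) {
    for (int mi = 0; mi < m; ++mi) {
      for (int ci = 0; ci < in_channel; ++ci) {
        float16_t acc = 0;
        for (int z = 0; z < k; ++z) {
          // matrix_a laid out as (m, k, in_channel), matrix_b as (k, n)
          acc += matrix_a[(mi * k + z) * in_channel + ci] * matrix_b[z * n + ni];
        }
        // result laid out as (n, m, in_channel), i.e. already transposed in (m, n)
        matrix_c[(ni * m + mi) * in_channel + ci] = acc;
      }
    }
  }
}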

@ -32,6 +32,24 @@ void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, fl
  }
}

#ifndef ENABLE_ARM64
void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k,
                                int n, int in_channel) {
  int cnt = 0;
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int y = 0; y < in_channel; ++y) {
        float16_t tmp = 0;
        for (int z = 0; z < k; ++z) {
          tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n];
        }
        matrix_c[cnt++] = tmp;
      }
    }
  }
}
#endif

void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c,
                           const float16_t *bias, int m, int k, int n) {
  if (bias == NULL) {
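
As a usage sketch (illustrative sizes, not from the patch), the filter-transform code later in this change calls the fallback with m = kernel_unit_, k = kernel_unit_, n = input_unit_ to compute g * GT for one output channel:

// Hypothetical example: 3x3 kernel (kernel_unit_ = 3), 8x8 input unit (input_unit_ = 8), 16 input channels.
float16_t g[3 * 3 * 16];    // one output channel of the weight, layout (h, w, in_channel)
float16_t gt[3 * 8];        // GT stored row-major: k = 3 rows, n = 8 columns
float16_t tmp[3 * 8 * 16];  // g * GT, layout (m, n, in_channel) on this C path
MatrixMultiplyWinogradFp16(g, gt, tmp, /*m=*/3, /*k=*/3, /*n=*/8, /*in_channel=*/16);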

@ -26,6 +26,8 @@ void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, fl
void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c,
                           const float16_t *bias, int m, int k, int n);
void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k,
                                int n, int in_channel);
#ifdef __cplusplus
}
#endif

@ -65,6 +65,14 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
  }  // tile num loop
}

void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel) {
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float16_t));
    }
  }
}

void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) {
  // origin weight format : ohwi
  int input_channel = conv_param->input_channel_;
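
The helper simply swaps the two outer dimensions while keeping each channel vector contiguous; a small worked example (illustrative values):

// For height = 2, width = 3, channel = 1, PackHWCToWHCFp16 maps
//   src = [ a b c ]        dst = [ a d ]
//         [ d e f ]   to         [ b e ]
//                                [ c f ]
// i.e. src[(i * width + j) * channel + c] is copied to dst[(j * height + i) * channel + c].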

@ -31,6 +31,8 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel);
void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel);

@ -36,91 +36,106 @@ using mindspore::schema::PrimitiveType_Conv2D;
 namespace mindspore::kernel {
 int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g,
                                                                   float *matrix_gt, int oc_block) {
+  if (oc_block == 0) {
+    MS_LOG(ERROR) << "Divide by zero";
+    return RET_ERROR;
+  }
   // original weight format : ohwi
   auto channel_in = conv_param_->input_channel_;
   auto channel_out = conv_param_->output_channel_;
-  int input_unit_square = input_unit_ * input_unit_;
   int oc_block_num = UP_DIV(channel_out, oc_block);
-  auto matrix_g_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
-  if (matrix_g_data_fp16 == nullptr) {
-    MS_LOG(ERROR) << "malloc matrix_g_data_fp16 failed.";
-    return RET_ERROR;
-  }
+  int block_stride = channel_in * oc_block;
+  int block_num_stride = block_stride * oc_block_num;
   auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
   if (matrix_gt_data_fp16 == nullptr) {
-    free(matrix_g_data_fp16);
     MS_LOG(ERROR) << "malloc matrix_gt_data_fp16 failed.";
     return RET_ERROR;
   }
-  Float32ToFloat16(matrix_g, matrix_g_data_fp16, input_unit_ * kernel_unit_);
   Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit_ * kernel_unit_);
-  // trans_filter = G*g*GT (g represents weight_data)
-  // separate into two steps ===> tmp = G*g ===> out = tmp * GT
-  auto tmp_weight_data = reinterpret_cast<float16_t *>(malloc(kernel_unit_ * kernel_unit_ * sizeof(float16_t)));
-  if (tmp_weight_data == nullptr) {
-    free(matrix_g_data_fp16);
-    free(matrix_gt_data_fp16);
-    MS_LOG(ERROR) << "malloc tmp_weight_data failed.";
-    return RET_ERROR;
-  }
-  auto tmp_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
+  // trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T
+  // separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T
+  auto tmp_data = reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float16_t)));
   if (tmp_data == nullptr) {
-    free(tmp_weight_data);
-    free(matrix_g_data_fp16);
     free(matrix_gt_data_fp16);
     MS_LOG(ERROR) << "malloc tmp_data failed.";
     return RET_ERROR;
   }
-  auto trans_out_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * input_unit_ * sizeof(float16_t)));
+  auto trans_out_data =
+    reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t)));
   if (trans_out_data == nullptr) {
     free(tmp_data);
-    free(tmp_weight_data);
-    free(matrix_g_data_fp16);
     free(matrix_gt_data_fp16);
     MS_LOG(ERROR) << "malloc trans_out_data failed.";
     return RET_ERROR;
   }
-  if (oc_block == 0) {
-    MS_LOG(ERROR) << "Divide by zero";
-    free(tmp_weight_data);
-    free(tmp_data);
-    free(trans_out_data);
-    free(matrix_g_data_fp16);
-    free(matrix_gt_data_fp16);
-    return RET_ERROR;
-  }
-  int stride1 = channel_in * oc_block;
+#ifndef ENABLE_ARM64
+  auto tmp_data1 = reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float16_t)));
+  if (tmp_data1 == nullptr) {
+    free(tmp_data);
+    free(matrix_gt_data_fp16);
+    free(trans_out_data);
+    MS_LOG(ERROR) << "malloc tmp_data1 failed.";
+    return RET_ERROR;
+  }
+  auto trans_out_data1 =
+    reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t)));
+  if (trans_out_data1 == nullptr) {
+    free(tmp_data);
+    free(tmp_data1);
+    free(matrix_gt_data_fp16);
+    free(trans_out_data);
+    MS_LOG(ERROR) << "malloc trans_out_data1 failed.";
+    return RET_ERROR;
+  }
+#endif
+  int input_oz_offset = kernel_unit_ * kernel_unit_ * channel_in;
   for (int i = 0; i < channel_out; i++) {
     int out_c_block = i / oc_block;
     int out_c_res = i % oc_block;
-    int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in;
-    int output_oz_offset = out_c_block * stride1 + out_c_res;
-    for (int j = 0; j < channel_in; j++) {
-      int input_iz_offset = input_oz_offset + j;
-      int output_iz_offset = output_oz_offset + j * oc_block;
-      for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) {
-        int input_xy_offset = input_iz_offset + k * channel_in;
-        tmp_weight_data[k] = *(weight_data + input_xy_offset);
-      }
-      // now we only support row-major matrix-multiply
-      // tmp = G * g
-      MatrixMultiplyFp16(matrix_g_data_fp16, tmp_weight_data, tmp_data, input_unit_, kernel_unit_, kernel_unit_);
-      // out = tmp * GT
-      MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_);
-      for (int z = 0; z < input_unit_square; z++) {
-        int output_xy_offset = output_iz_offset + z * oc_block_num * stride1;
-        trans_weight_[output_xy_offset] = trans_out_data[z];
-      }
-    }
+    int output_oz_offset = out_c_block * block_stride + out_c_res;
+#ifndef ENABLE_ARM64
+    // tmp_data = g * GT
+    MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_,
+                               kernel_unit_, input_unit_, channel_in);
+    // tmp_data1 = (tmp_data)T
+    PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit_, input_unit_, channel_in);
+    // trans_out_data1 = tmp * GT
+    MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit_, kernel_unit_, input_unit_,
+                               channel_in);
+    // trans_out_data = (trans_out_data1)T
+    PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit_, input_unit_, channel_in);
+#else
+    // tmp = (g * GT)T
+    MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_,
+                               kernel_unit_, input_unit_, channel_in);
+    // trans = (tmp * GT)T
+    MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_,
+                               channel_in);
+#endif
+    int in_offset = 0;
+    for (int j = 0; j < input_unit_; ++j) {
+      for (int k = 0; k < input_unit_; ++k) {
+        for (int c = 0; c < channel_in; ++c) {
+          *(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c];
+        }
+        in_offset += channel_in;
+        output_oz_offset += block_num_stride;
+      }
+    }
   }
-  free(tmp_weight_data);
+#ifndef ENABLE_ARM64
+  free(tmp_data1);
+  free(trans_out_data1);
+#endif
   free(tmp_data);
   free(trans_out_data);
-  free(matrix_g_data_fp16);
   free(matrix_gt_data_fp16);
   return RET_OK;
 }
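
A quick check of the transpose identity that the new two-step transform relies on (standard matrix algebra, using (A*B)T = (B)T*(A)T; not part of the patch):

[(g * (G)T)T * (G)T]T = ((G)T)T * [(g * (G)T)T]T = G * (g * (G)T) = G * g * (G)T

This is why the rewritten code only needs matrix_gt_data_fp16: both multiplications use GT, and the intermediate transposes are either performed explicitly by PackHWCToWHCFp16 on the C path or absorbed into the transposed output layout of the ARM64 routine.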
