|
|
|
@ -328,36 +328,36 @@ void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4) {
|
|
|
|
|
float dst01 = (local_ptr + 4)[0];
|
|
|
|
|
float dst02 = (local_ptr + 8)[0];
|
|
|
|
|
|
|
|
|
|
float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
|
|
|
|
|
float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
|
|
|
|
|
float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
|
|
|
|
|
const float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
|
|
|
|
|
const float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
|
|
|
|
|
const float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
|
|
|
|
|
|
|
|
|
|
float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
|
|
|
|
|
float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
|
|
|
|
|
float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
|
|
|
|
|
const float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
|
|
|
|
|
const float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
|
|
|
|
|
const float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
|
|
|
|
|
|
|
|
|
|
float dst30 = (local_ptr + 24)[0];
|
|
|
|
|
float dst31 = (local_ptr + 28)[0];
|
|
|
|
|
float dst32 = (local_ptr + 32)[0];
|
|
|
|
|
|
|
|
|
|
float m00 = dst00;
|
|
|
|
|
float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02;
|
|
|
|
|
float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02;
|
|
|
|
|
const float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02;
|
|
|
|
|
const float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02;
|
|
|
|
|
float m03 = dst02;
|
|
|
|
|
|
|
|
|
|
float m10 = dst10;
|
|
|
|
|
float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12;
|
|
|
|
|
float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12;
|
|
|
|
|
const float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12;
|
|
|
|
|
const float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12;
|
|
|
|
|
float m13 = dst12;
|
|
|
|
|
|
|
|
|
|
float m20 = dst20;
|
|
|
|
|
float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22;
|
|
|
|
|
float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22;
|
|
|
|
|
const float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22;
|
|
|
|
|
const float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22;
|
|
|
|
|
float m23 = dst22;
|
|
|
|
|
|
|
|
|
|
float m30 = dst30;
|
|
|
|
|
float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32;
|
|
|
|
|
float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32;
|
|
|
|
|
const float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32;
|
|
|
|
|
const float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32;
|
|
|
|
|
float m33 = dst32;
|
|
|
|
|
|
|
|
|
|
*(dst + j) = m00;
|
|
|
|
@ -387,7 +387,7 @@ void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4) {
|
|
|
|
|
void ConvDw3x3Fp32InputTrans(const float *input_data, float *trans_input, float *block_buffer, int out_h_block,
|
|
|
|
|
int out_w_block, const ConvParameter *conv_param) {
|
|
|
|
|
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
|
|
|
|
|
int input_unit = 4;
|
|
|
|
|
const int input_unit = 4;
|
|
|
|
|
memset(trans_input, 0, out_h_block * out_h_block * 16 * C4NUM * sizeof(float));
|
|
|
|
|
|
|
|
|
|
for (int oh = 0; oh < out_h_block; oh++) {
|
|
|
|
@ -426,7 +426,7 @@ void ConvDw3x3Fp32InputTrans(const float *input_data, float *trans_input, float
|
|
|
|
|
|
|
|
|
|
// todo yangruoqi: implement assembly
|
|
|
|
|
void ConvDw3x3Fp32Winograd(float *trans_buffer, const float *weight, int out_h_block, int out_w_block) {
|
|
|
|
|
int unit = 4;
|
|
|
|
|
const int unit = 4;
|
|
|
|
|
for (int oh = 0; oh < out_h_block; oh++) {
|
|
|
|
|
float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM;
|
|
|
|
|
for (int ow = 0; ow < out_w_block; ow++) {
|
|
|
|
@ -583,7 +583,7 @@ void ConvDw3x3Fp32OutputTrans(float *trans_buffer, float *output_data, const flo
|
|
|
|
|
int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
|
|
|
|
|
bool h_in_range = true;
|
|
|
|
|
for (int oh = 0; oh < out_h_block; oh++) {
|
|
|
|
|
int real_oh = 2 * oh;
|
|
|
|
|
const int real_oh = 2 * oh;
|
|
|
|
|
if ((oh + 1) * 2 > conv_param->output_h_) {
|
|
|
|
|
h_in_range = false;
|
|
|
|
|
}
|
|
|
|
@ -592,7 +592,7 @@ void ConvDw3x3Fp32OutputTrans(float *trans_buffer, float *output_data, const flo
|
|
|
|
|
float *output_oh = output_data + real_oh * conv_param->output_w_ * oc4 * C4NUM;
|
|
|
|
|
|
|
|
|
|
for (int ow = 0; ow < out_w_block; ow++) {
|
|
|
|
|
int real_ow = 2 * ow;
|
|
|
|
|
const int real_ow = 2 * ow;
|
|
|
|
|
if ((ow + 1) * 2 > conv_param->output_w_) {
|
|
|
|
|
w_in_range = false;
|
|
|
|
|
}
|
|
|
|
|