!10012 [MS][LITE][Develop]avx fp32 matmul kernel support for deconv

From: @lx0095
Reviewed-by: @zhanghaibo5,@zhang_xue_tong,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong,@zhang_xue_tong
pull/10012/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 58cb834733

File diff suppressed because it is too large Load Diff

@ -78,3 +78,52 @@ void Relu6Fp32(float *data, float *dst, int ele_num) {
data[j] = data[j] > 6 ? 6 : data[j];
}
}
#ifdef ENABLE_AVX
#ifdef WIN32
// In-place ReLU over ele_num floats: data[i] = max(data[i], 0).
// Scalar fallback used by the WIN32 AVX build, where the assembly kernel
// is unavailable. Callers pass the same pointer for data and dst; the
// result is written back into data, and dst is unused.
// Fix: the previous manually-unrolled version started the remainder loop at
// (UP_DIV(ele_num, C8NUM) - 1) * C8NUM, which is negative when ele_num == 0
// and accessed memory out of bounds (it also misnamed the C8 block count
// "four_block"). A plain loop handles every length, including 0, and lets
// the compiler auto-vectorize.
void ReluFp32C8(float *data, float *dst, int ele_num) {
  (void)dst;  // in-place: output overwrites the input buffer
  for (int i = 0; i < ele_num; ++i) {
    data[i] = data[i] < 0 ? 0 : data[i];
  }
}
// In-place ReLU6 over ele_num floats: data[i] = min(max(data[i], 0), 6).
// Scalar fallback used by the WIN32 AVX build, where the assembly kernel
// is unavailable. Callers pass the same pointer for data and dst; the
// result is written back into data, and dst is unused.
// Fix: the previous manually-unrolled version started the remainder loop at
// (UP_DIV(ele_num, C8NUM) - 1) * C8NUM, which is negative when ele_num == 0
// and accessed memory out of bounds (it also misnamed the C8 block count
// "four_block"). A plain loop handles every length, including 0, and lets
// the compiler auto-vectorize.
void Relu6Fp32C8(float *data, float *dst, int ele_num) {
  (void)dst;  // in-place: output overwrites the input buffer
  for (int i = 0; i < ele_num; ++i) {
    float v = data[i] < 0 ? 0 : data[i];  // clamp below at 0
    data[i] = v > 6 ? 6 : v;              // clamp above at 6
  }
}
#endif
#endif

@ -31,6 +31,12 @@ int8_t MinInt8(int8_t a, int8_t b);
int8_t MaxInt8(int8_t a, int8_t b);
void ReluFp32(float *data, float *dst, int ele_num);
void Relu6Fp32(float *data, float *dst, int ele_num);
#ifdef ENABLE_AVX
#ifdef WIN32
void ReluFp32C8(float *data, float *dst, int ele_num);
void Relu6Fp32C8(float *data, float *dst, int ele_num);
#endif
#endif
int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3);
int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2);
int offset4d(const int *shape, const int *dims);

@ -681,6 +681,47 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
#endif
#ifdef ENABLE_AVX
#ifdef WIN32
// Computes one output row of a depthwise convolution through an indirection
// buffer (scalar fallback for the WIN32 AVX build, where the assembly
// kernel is not compiled in).
//   output       - destination; output_width pixels of `channels` floats each.
//   input        - indirection buffer: per output pixel, `kernel` pointers to
//                  the contributing input lines; advanced by input_stride
//                  pointers after each pixel.
//   weights      - weights indexed as w[i + k * C8NUM], i.e. grouped in
//                  C8NUM-wide channel blocks per kernel tap.
//   bias         - `channels` floats copied into each output pixel before
//                  accumulation.
//   relu/relu6   - optional activations applied in place per pixel.
// NOTE(review): `float *in[kernel]` is a C99 variable-length array; MSVC does
// not support VLAs, so this WIN32 path presumably targets MinGW/clang —
// confirm the toolchain.
// NOTE(review): `do { ... } while (--output_width != 0)` runs at least once
// and wraps around if output_width == 0 — confirm callers guarantee > 0.
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  do {
    // Snapshot the `kernel` input-line pointers for this output pixel.
    float *in[kernel];
    for (int k = 0; k < kernel; k++) {
      in[k] = input[k];
    }
    input = input + input_stride;
    size_t c = channels;
    const float *w = weights;
    float *out = output;
    // Start each pixel from the bias, then accumulate one product per tap.
    memcpy(out, bias, channels * sizeof(float));
    // Main loop: full C8NUM-wide channel blocks.
    for (; c >= C8NUM; c -= C8NUM) {
      for (int i = 0; i < C8NUM; i++) {
        for (int k = 0; k < kernel; k++) {
          out[i] += in[k][i] * w[i + k * C8NUM];
        }
      }
      w += kernel * C8NUM;
      out += C8NUM;
      for (int k = 0; k < kernel; k++) {
        in[k] += C8NUM;
      }
    }
    // Remainder channels (< C8NUM); the weight layout stays C8NUM-strided.
    // NOTE(review): `i < c` compares int with size_t — harmless here since
    // c < C8NUM at this point, but triggers signed/unsigned warnings.
    for (int i = 0; i < c; i++) {
      for (int k = 0; k < kernel; k++) {
        out[i] += in[k][i] * w[i + k * C8NUM];
      }
    }
    // Optional fused activations, applied in place on this pixel's channels.
    if (relu) {
      ReluFp32C8(output, output, channels);
    }
    if (relu6) {
      Relu6Fp32C8(output, output, channels);
    }
    output += channels;
  } while (--output_width != 0);
}
#else
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
int output_width, int input_stride, bool relu, bool relu6, int kernel) {
if (kernel == 9) {
@ -688,6 +729,7 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
}
}
#endif
#endif
void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
float *zero_ptr, const ConvParameter *conv_param, int task_id) {

@ -67,9 +67,11 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c
#endif
#ifdef ENABLE_AVX
#ifndef WIN32
void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels,
int output_width, size_t input_stride, size_t relu, size_t relu6);
#endif
#endif
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
int output_width, int input_stride, bool relu, bool relu6, int kernel);

@ -147,7 +147,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
conv_param->input_channel_ = inputs[kInputIndex]->Channel();
conv_param->output_h_ = outputs[kOutputIndex]->Height();
conv_param->output_w_ = outputs[kOutputIndex]->Width();
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
#ifdef ENABLE_AVX
if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) {
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
#elif defined(ENABLE_ARM64)
if (CheckConvDwUseIndirectBuffer(conv_param)) {
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive);

Loading…
Cancel
Save