add 1D F(2,3) support for 3x3 dw conv

pull/13138/head
lixian 4 years ago
parent dca301eabf
commit aec6dfd513

File diff suppressed because it is too large

@@ -47,12 +47,6 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig
 bool CheckConvDwUse3X3(const ConvParameter *conv_param);
-void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
-                  const ConvParameter *conv_param, const SlidingWindowParam *sliding);
-void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
-               const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
 bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param);
 void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param,
@@ -74,6 +68,13 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const
                       size_t output_width, size_t input_stride, size_t relu, size_t relu6);
 #endif
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
+               const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh);
+bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num);
+#endif
 void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                            int output_width, int input_stride, bool relu, bool relu6, int kernel);
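Note: the new ConvDw3x3 signature drops the SlidingWindowParam/task_id pair in favor of an explicit half-open output-row range [start_oh, end_oh), so the caller owns the per-thread row split; CheckConvDw1DWinograd is the predicate gating this 1D Winograd path (its definition is presumably in the large suppressed file above).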

@@ -632,3 +632,23 @@ inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_st
   _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma);
 }
 #endif
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) {
+  // nchw to nc4hw4 with 1D F(2,3)
+  for (int i = 0; i < channel; i++) {
+    float *src_kernel = (float *)src + i * 9;
+    float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4;
+    for (int y = 0; y < 3; y++) {
+      float g0 = src_kernel[3 * y];
+      float g1 = src_kernel[3 * y + 1];
+      float g2 = src_kernel[3 * y + 2];
+      dst_kernel[16 * y] = g0;
+      dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2);
+      dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2);
+      dst_kernel[16 * y + 12] = g2;
+    }
+  }
+}
+#endif
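The 0.5f * (g0 ± g1 + g2) lines above are the per-row F(2,3) weight transform U = G·g with the standard Winograd generator G = [[1,0,0],[1/2,1/2,1/2],[1/2,-1/2,1/2],[0,0,1]]: each 3-tap kernel row expands to 4 values, hence 48 floats per 4-channel block (4 taps × 3 rows × 4 channels) with a row stride of 16. The matching input/output transforms are not in this diff; below is a minimal scalar sketch, assuming the standard B^T = [[1,0,-1,0],[0,1,1,0],[0,-1,1,0],[0,1,0,-1]] and A^T = [[1,1,1,0],[0,1,-1,-1]] matrices, that checks one tile against direct convolution:

#include <stdio.h>

int main(void) {
  const float g[3] = {0.25f, -1.5f, 2.0f};      /* one 3-tap kernel row */
  const float d[4] = {1.0f, 2.0f, 3.0f, 4.0f};  /* a 4-wide input tile */

  /* Weight transform U = G*g: exactly the four values the packer stores per row. */
  const float u[4] = {g[0], 0.5f * (g[0] + g[1] + g[2]), 0.5f * (g[0] - g[1] + g[2]), g[2]};
  /* Input transform V = B^T * d. */
  const float v[4] = {d[0] - d[2], d[1] + d[2], d[2] - d[1], d[1] - d[3]};
  /* Elementwise product, then output transform y = A^T * (U.V): 2 outputs from 4 multiplies. */
  const float m[4] = {u[0] * v[0], u[1] * v[1], u[2] * v[2], u[3] * v[3]};
  const float y0 = m[0] + m[1] + m[2];
  const float y1 = m[1] - m[2] - m[3];

  /* Direct 3-tap reference: must match y0/y1. */
  const float r0 = g[0] * d[0] + g[1] * d[1] + g[2] * d[2];
  const float r1 = g[0] * d[1] + g[1] * d[2] + g[2] * d[3];
  printf("winograd = (%g, %g), direct = (%g, %g)\n", y0, y1, r0, r1);
  return 0;
}

Each unit thus trades 6 multiplies for 4 (plus cheap adds), which is the point of taking this path over the plain sliding window.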

@@ -44,6 +44,10 @@ void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, i
 void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
                         int block_index);
 
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel);
+#endif
+
 // Transpose 8X8 Fp32 block data
 typedef void (*Transpose8X8Fp32Func)(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride);
 #ifdef ENABLE_ARM64
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64

@@ -32,7 +32,6 @@
 #define MS_ADDQ_EPI32 vaddq_s32
 #define MS_MOVQ_F32 vmovq_n_f32
 #define MS_MOVQ_EPI32 vmovq_n_s32
-#define MS_DUPQ_F32 vdupq_n_f32  // It is recommended to replace with MS_MOVQ_F32.
 #define MS_SUBQ_F32 vsubq_f32
 #define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
 #define MS_STQ_F32 vst1q_f32
@@ -76,7 +75,6 @@ inline static float32x4_t vrecp(float32x4_t v) {
 #define MS_ADD256_EPI32 _mm256_add_epi32
 #define MS_MOV256_F32 _mm256_set1_ps
 #define MS_MOV256_EPI32 _mm256_set1_epi32
-#define MS_DUP256_F32 _mm256_load_ps1  // It is recommended to replace with MS_MOV256_F32.
 #define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3))
 #define MS_ST256_F32 _mm256_storeu_ps
 #define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)
@@ -109,7 +107,6 @@ inline static float32x4_t vrecp(float32x4_t v) {
 #define MS_ADDQ_EPI32 _mm_add_epi32
 #define MS_MOVQ_F32 _mm_set1_ps
 #define MS_MOVQ_EPI32 _mm_set1_epi32
-#define MS_DUPQ_F32 _mm_load_ps1  // It is recommended to replace with MS_MOVQ_F32.
 #define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
 #define MS_STQ_F32 _mm_storeu_ps
 #define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)

@@ -21,6 +21,7 @@
 #include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h"
 #include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h"
 #include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h"
+#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h"
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
@@ -354,8 +355,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
   kernel::LiteKernel *kernel = nullptr;
   if (opParameter != nullptr && opParameter->infer_flag_) {
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+    if (CheckConvDw1DWinograd(conv_param, ctx->thread_num_)) {
+      kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx);
+    }
+#endif
 #if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
-    if (CheckConvDwUseIndirectBuffer(conv_param)) {
+    if (kernel == nullptr && CheckConvDwUseIndirectBuffer(conv_param)) {
       kernel = new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx);
     }
 #endif
@@ -367,7 +373,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
     kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
   }
   return kernel;
-}
+}  // namespace mindspore::kernel
 
 kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                              const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
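Note: the creator now tries the specialized kernels in priority order, each later branch guarded by kernel == nullptr: the new 1D Winograd 3x3 depthwise kernel when CheckConvDw1DWinograd accepts the shape and thread count, then the indirect-buffer kernel on ARM64/AVX, and finally the existing sliding-window/generic fallbacks.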

@@ -18,8 +18,10 @@
 #include "include/errorcode.h"
 #include "src/runtime/runtime_api.h"
 
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_INFER_INVALID;
+using mindspore::lite::RET_MEMORY_FAILED;
 using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
@@ -28,10 +30,6 @@ ConvolutionDepthwise3x3CPUKernel::~ConvolutionDepthwise3x3CPUKernel() {
     free(packed_weight_);
     packed_weight_ = nullptr;
   }
-  if (sliding_ != nullptr) {
-    delete sliding_;
-    sliding_ = nullptr;
-  }
 }
 
 int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
@@ -39,22 +37,26 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
   auto weight_tensor = in_tensors_[kWeightIndex];
   auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
   int channel = weight_tensor->Batch();
-  int pack_weight_size = weight_tensor->Batch() * weight_tensor->Height() * weight_tensor->Width();
-  packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
+  int c4 = UP_ROUND(channel, C4NUM);
+  int pack_weight_size = c4 * C12NUM;
   if (packed_weight_ == nullptr) {
-    MS_LOG(ERROR) << "Malloc buffer failed.";
-    return RET_ERROR;
+    packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
+    if (packed_weight_ == nullptr) {
+      MS_LOG(ERROR) << "Malloc buffer failed.";
+      return RET_ERROR;
+    }
   }
-  PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), channel);
-  bias_data_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
+  PackWeightConvDw3x3Fp32(origin_weight, packed_weight_, channel);
   if (bias_data_ == nullptr) {
-    MS_LOG(ERROR) << "Malloc buffer failed.";
-    return RET_ERROR;
+    bias_data_ = reinterpret_cast<float *>(malloc(c4 * sizeof(float)));
+    if (bias_data_ == nullptr) {
+      MS_LOG(ERROR) << "Malloc buffer failed.";
+      return RET_ERROR;
+    }
   }
-  memset(bias_data_, 0, channel * sizeof(float));
+  memset(bias_data_, 0, c4 * sizeof(float));
   if (in_tensors_.size() == kInputSize2) {
     auto bias_tensor = in_tensors_[kBiasIndex];
     auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
@@ -65,11 +67,6 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
 }
 
 int ConvolutionDepthwise3x3CPUKernel::Init() {
-  sliding_ = new (std::nothrow) SlidingWindowParam;
-  if (sliding_ == nullptr) {
-    MS_LOG(ERROR) << "new sliding window param failed.";
-    return RET_ERROR;
-  }
   auto ret = InitWeightBias();
   if (ret != 0) {
     MS_LOG(ERROR) << "Convolution depthwise 3x3 fp32 InitWeightBias failed.";
@@ -83,15 +80,19 @@ int ConvolutionDepthwise3x3CPUKernel::Init() {
 
 int ConvolutionDepthwise3x3CPUKernel::ReSize() {
   ConvolutionBaseCPUKernel::Init();
-  InitSlidingParamConvDw(sliding_, conv_param_, conv_param_->input_channel_);
   conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
   return RET_OK;
 }
 
 int ConvolutionDepthwise3x3CPUKernel::Execute(int task_id) {
-  auto buffer = buffer_ + 64 * 10 * 10 * task_id;
+  int units = UP_DIV(conv_param_->output_w_, C2NUM);  // F(2, 3) contains 2 conv units
+  int c4 = UP_ROUND(conv_param_->input_channel_, C4NUM);
+  auto buffer = buffer_ + C12NUM * c4 * units * task_id;
+  int step_oh = UP_DIV(conv_param_->output_h_, conv_param_->thread_num_);
+  int start_oh = step_oh * task_id;
+  int end_oh = MSMIN(start_oh + step_oh, conv_param_->output_h_);
   ConvDw3x3(output_ptr_, buffer, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_,
-            sliding_, task_id);
+            start_oh, end_oh);
   return RET_OK;
 }
@@ -105,25 +106,18 @@ int ConvDw3x3Run(void *cdata, int task_id) {
   return RET_OK;
 }
 
-int ConvolutionDepthwise3x3CPUKernel::InitBuffer() {
-  int buffer_size = 64 * 10 * 10 * conv_param_->thread_num_;
-  buffer_ = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size * sizeof(float)));
-  if (buffer_ == nullptr) {
-    MS_LOG(ERROR) << "Malloc buffer failed.";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
-
 int ConvolutionDepthwise3x3CPUKernel::Run() {
-  auto ret = InitBuffer();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Depthwise int8 ReSize error!";
-    return ret;
+  int units = UP_DIV(conv_param_->output_w_, C2NUM);  // F(2, 3) contains 2 conv units
+  int c4 = UP_ROUND(conv_param_->input_channel_, C4NUM);
+  int buffer_size = units * c4 * C12NUM * conv_param_->thread_num_;
+  buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(buffer_size * sizeof(float)));
+  if (buffer_ == nullptr) {
+    MS_LOG(ERROR) << "ConvDw3x3Run failed to allocate buffer";
+    return RET_MEMORY_FAILED;
   }
   if (IsTrain() && is_trainable()) {
-    PackWeight();
+    InitWeightBias();
   }
 
   auto input_tensor = in_tensors_.at(kInputIndex);
@@ -132,32 +126,21 @@ int ConvolutionDepthwise3x3CPUKernel::Run() {
   auto output_tensor = out_tensors_.at(kOutputIndex);
   output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());
-  if (sliding_->top_ > 0 || sliding_->bottom_ < conv_param_->output_h_ || sliding_->left_ > 0 ||
-      sliding_->right_ < conv_param_->output_w_) {
-    ConvDw3x3Pad(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_, sliding_);
-  }
-  ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Run, this, conv_param_->thread_num_);
+  auto ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Run, this, conv_param_->thread_num_);
+  ctx_->allocator->Free(buffer_);
   if (ret != RET_OK) {
-    context_->allocator->Free(buffer_);
     MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]";
     return RET_ERROR;
   }
-  context_->allocator->Free(buffer_);
   return RET_OK;
 }
 
-void ConvolutionDepthwise3x3CPUKernel::PackWeight() {
-  auto weight_tensor = in_tensors_.at(kWeightIndex);
-  auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
-  PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
-                         weight_tensor->Batch());
-}
-
 int ConvolutionDepthwise3x3CPUKernel::Eval() {
   LiteKernel::Eval();
   if (is_trainable()) {
-    PackWeight();
+    InitWeightBias();
   }
   return RET_OK;
 }
 }  // namespace mindspore::kernel
+#endif
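Note: with PackWeight() removed, the trainable path (Run() and Eval()) re-runs InitWeightBias() to repack updated weights; the null checks now guarding the mallocs make the repeated call safe, since packed_weight_ and bias_data_ are allocated only once.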

@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
 #include <vector>
 #include "src/lite_kernel.h"
 #include "src/runtime/kernel/arm/base/convolution_base.h"
@@ -39,14 +40,11 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel {
   int Eval() override;
 
  private:
-  void PackWeight();
-  int InitBuffer();
-  SlidingWindowParam *sliding_ = nullptr;
   float *packed_weight_ = nullptr;
   float *input_ptr_ = nullptr;
   float *output_ptr_ = nullptr;
   float *buffer_ = nullptr;
 };
 }  // namespace mindspore::kernel
+#endif
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
