parent 0c7ba7a7fa
commit 415539655d
@@ -0,0 +1,146 @@
#ifdef __aarch64__

.text
.align 5
.global ConvDwFp32Indirect3x3
#ifndef __APPLE__
.type ConvDwFp32Indirect3x3, %function
#endif

// void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width,
//                            size_t input_stride, size_t relu, size_t relu6)
// x0: output, x1: input, x2: weights, x3: bias, x4: channels, x5: output_width, x6: input_stride, x7: relu, x8: relu6

ConvDwFp32Indirect3x3:
    // save callee-saved x19/x20
    sub sp, sp, #16
    stp x19, x20, [sp], #16

    movi v31.4s, #6
    scvtf v31.4s, v31.4s        // v31 = 6.0f, upper bound for relu6
    dup v30.4s, wzr             // v30 = 0.0f, lower bound for relu

    ldr x8, [sp]                // relu6 flag: 9th argument, passed on the stack
    cmp x5, #0
    beq End

LoopPixel:
    // load the nine input-row pointers of the 3x3 window from the indirection buffer
    ldp x12, x13, [x1]
    ldp x14, x15, [x1, #16]
    ldp x16, x17, [x1, #32]
    ldp x18, x19, [x1, #48]
    ldr x20, [x1, #64]
    mov x9, x2                  // weights
    mov x10, x3                 // bias
    mov x11, x4                 // remaining channels

    ld1 {v0.4s}, [x12], #16
    ld1 {v1.4s}, [x13], #16
    ld1 {v2.4s}, [x14], #16

    ld1 {v17.4s}, [x9], #16
    ld1 {v18.4s}, [x9], #16
    ld1 {v19.4s}, [x9], #16

    ld1 {v29.4s}, [x10], #16    // start the accumulator from the bias
    cmp x11, #4
    ble LeftLoop
LoopC4:
    // accumulate the nine taps for four channels, loading the next block as we go
    fmla v29.4s, v0.4s, v17.4s
    ld1 {v3.4s}, [x15], #16
    ld1 {v20.4s}, [x9], #16
    fmla v29.4s, v1.4s, v18.4s
    ld1 {v4.4s}, [x16], #16
    ld1 {v21.4s}, [x9], #16
    fmla v29.4s, v2.4s, v19.4s
    ld1 {v5.4s}, [x17], #16
    ld1 {v22.4s}, [x9], #16
    fmla v29.4s, v3.4s, v20.4s
    ld1 {v6.4s}, [x18], #16
    ld1 {v23.4s}, [x9], #16
    fmla v29.4s, v4.4s, v21.4s
    ld1 {v7.4s}, [x19], #16
    ld1 {v24.4s}, [x9], #16
    fmla v29.4s, v5.4s, v22.4s
    ld1 {v16.4s}, [x20], #16
    ld1 {v25.4s}, [x9], #16
    fmla v29.4s, v6.4s, v23.4s
    ld1 {v0.4s}, [x12], #16
    ld1 {v17.4s}, [x9], #16
    fmla v29.4s, v7.4s, v24.4s
    ld1 {v1.4s}, [x13], #16
    ld1 {v18.4s}, [x9], #16
    fmla v29.4s, v16.4s, v25.4s
    ld1 {v2.4s}, [x14], #16
    ld1 {v19.4s}, [x9], #16

    cbnz x8, Relu6
    cbnz x7, Relu
    b Write
Relu6:
    fmin v29.4s, v29.4s, v31.4s
Relu:
    fmax v29.4s, v29.4s, v30.4s
Write:
    st1 {v29.4s}, [x0], #16

    ld1 {v29.4s}, [x10], #16    // reload bias for the next four channels
    sub x11, x11, #4
    cmp x11, #4
    bgt LoopC4

LeftLoop:
    // tail: at most four channels left, so nothing further to prefetch
    fmla v29.4s, v0.4s, v17.4s
    ld1 {v3.4s}, [x15], #16
    ld1 {v20.4s}, [x9], #16
    fmla v29.4s, v1.4s, v18.4s
    ld1 {v4.4s}, [x16], #16
    ld1 {v21.4s}, [x9], #16
    fmla v29.4s, v2.4s, v19.4s
    ld1 {v5.4s}, [x17], #16
    ld1 {v22.4s}, [x9], #16
    fmla v29.4s, v3.4s, v20.4s
    ld1 {v6.4s}, [x18], #16
    ld1 {v23.4s}, [x9], #16
    fmla v29.4s, v4.4s, v21.4s
    ld1 {v7.4s}, [x19], #16
    ld1 {v24.4s}, [x9], #16
    fmla v29.4s, v5.4s, v22.4s
    ld1 {v16.4s}, [x20], #16
    ld1 {v25.4s}, [x9], #16
    fmla v29.4s, v6.4s, v23.4s
    fmla v29.4s, v7.4s, v24.4s
    fmla v29.4s, v16.4s, v25.4s

    cbnz x8, LeftRelu6
    cbnz x7, LeftRelu
    b LeftWrite
LeftRelu6:
    fmin v29.4s, v29.4s, v31.4s
LeftRelu:
    fmax v29.4s, v29.4s, v30.4s
LeftWrite:
    cmp x11, #4
    bne Write3
    st1 {v29.4s}, [x0], #16
    b NextPixel
Write3:
    // store the remaining 1-3 channels
    sxtw x11, w11
    tbnz w11, #1, Write2
    tbnz w11, #0, Write1
Write2:
    str d29, [x0], #8
    ext v29.16b, v29.16b, v29.16b, #8
    tbz w11, #0, NextPixel
Write1:
    str s29, [x0], #4

NextPixel:
    add x1, x1, x6              // advance to the next pixel's pointer group
    sub x5, x5, #1
    cmp x5, #0
    bgt LoopPixel
End:
    sub sp, sp, #16
    ldp x19, x20, [sp], #16
    ret
#endif
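For reference, the documented signature above maps to the C-side declaration sketched below. This is a minimal sketch based only on the comment block in the assembly; the header it would normally live in is not shown in this diff.

#include <cstddef>
// Prototype taken from the comment block above. Per output pixel, `input` points at a group of
// nine row pointers (the indirection buffer), and the group is advanced by `input_stride` bytes.
extern "C" void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias,
                                      int channels, int output_width, size_t input_stride, size_t relu, size_t relu6);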
(File diff suppressed because it is too large.)
@@ -0,0 +1,182 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;

namespace mindspore::kernel {
ConvolutionDepthwiseIndirectCPUKernel::~ConvolutionDepthwiseIndirectCPUKernel() {
  if (packed_weight_ != nullptr) {
    free(packed_weight_);
    packed_weight_ = nullptr;
  }
  if (zero_ptr_ != nullptr) {
    free(zero_ptr_);
    zero_ptr_ = nullptr;
  }
  if (indirect_buffer_ != nullptr) {
    free(indirect_buffer_);
    indirect_buffer_ = nullptr;
  }
}

int ConvolutionDepthwiseIndirectCPUKernel::InitWeightBias() {
  // init weight: o, h, w, i; o == group, i == 1
  auto weight_tensor = in_tensors_[kWeightIndex];
  auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
  int C4 = UP_DIV(weight_tensor->Batch(), C4NUM);
  int pack_weight_size = C4NUM * C4 * weight_tensor->Height() * weight_tensor->Width();

  packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
  if (packed_weight_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  PackDepthwiseIndirectWeightC4Fp32(origin_weight, packed_weight_, weight_tensor->Height(), weight_tensor->Width(),
                                    weight_tensor->Batch());

  auto bias_tensor = in_tensors_[kBiasIndex];
  bias_data_ = reinterpret_cast<float *>(malloc(C4NUM * C4 * sizeof(float)));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }

  memset(bias_data_, 0, C4NUM * C4 * sizeof(float));
  if (in_tensors_.size() == kInputSize2) {
    auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
    memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
  }

  // malloc zero ptr
  zero_ptr_ = reinterpret_cast<float *>(malloc(C4NUM * C4 * sizeof(float)));
  if (zero_ptr_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  memset(zero_ptr_, 0, C4NUM * C4 * sizeof(float));
  return RET_OK;
}

int ConvolutionDepthwiseIndirectCPUKernel::Init() {
  auto ret = InitWeightBias();
  if (ret != 0) {
    MS_LOG(ERROR) << "Convolution depthwise Indirect fp32 InitWeightBias failed.";
    return RET_ERROR;
  }
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int ConvolutionDepthwiseIndirectCPUKernel::MallocIndirectBuffer() {
  // malloc indirect buffer
  step_w = conv_param_->dilation_w_ == 1 ? conv_param_->stride_w_ : conv_param_->kernel_w_;
  step_h =
    (conv_param_->kernel_h_ * conv_param_->kernel_w_) + (conv_param_->output_w_ - 1) * step_w * conv_param_->kernel_h_;
  int buffer_size = conv_param_->output_batch_ * conv_param_->output_h_ * step_h;
  indirect_buffer_ = reinterpret_cast<float **>(malloc(buffer_size * sizeof(float *)));
  if (indirect_buffer_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
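As a sanity check on the sizing above, here is a small standalone sketch (illustration only, with hypothetical layer dimensions) that evaluates the same step_w / step_h / buffer_size formulas for a 3x3 kernel with stride 1, dilation 1 and a 1x4x4 output:

// Illustrative only: recomputes the indirect-buffer sizing used in MallocIndirectBuffer().
#include <cstdio>

int main() {
  // hypothetical layer: 3x3 kernel, stride 1, dilation 1, output 1x4x4
  int kernel_h = 3, kernel_w = 3, stride_w = 1, dilation_w = 1;
  int output_batch = 1, output_h = 4, output_w = 4;

  int step_w = (dilation_w == 1) ? stride_w : kernel_w;
  int step_h = kernel_h * kernel_w + (output_w - 1) * step_w * kernel_h;  // pointer entries per output row
  int buffer_size = output_batch * output_h * step_h;                     // total float* entries to allocate

  printf("step_w=%d step_h=%d buffer_size=%d pointers\n", step_w, step_h, buffer_size);
  // prints: step_w=1 step_h=18 buffer_size=72 pointers
  return 0;
}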

int ConvolutionDepthwiseIndirectCPUKernel::ReSize() {
  if (indirect_buffer_ != nullptr) {
    free(indirect_buffer_);
    indirect_buffer_ = nullptr;
  }
  ConvolutionBaseCPUKernel::Init();
  auto ret = MallocIndirectBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionDepthwiseIndirect MallocIndirectBuffer failed";
    return RET_ERROR;
  }
  conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
  return RET_OK;
}

int ConvolutionDepthwiseIndirectCPUKernel::Execute(int task_id) {
  ConvDwIndirection(output_ptr_, indirect_buffer_, packed_weight_, reinterpret_cast<float *>(bias_data_), zero_ptr_,
                    conv_param_, task_id);
  return RET_OK;
}

int ConvDwIndirectRun(void *cdata, int task_id) {
  auto conv_dw = reinterpret_cast<ConvolutionDepthwiseIndirectCPUKernel *>(cdata);
  auto ret = conv_dw->Execute(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionDepthwiseIndirectRun error task_id[" << task_id << "] error_code[" << ret << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionDepthwiseIndirectCPUKernel::MallocPackedInput() {
  int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM);
  int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4;
  packed_input_ = reinterpret_cast<float *>(context_->allocator->Malloc(pack_input_size * sizeof(float)));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionDepthwiseIndirectCPUKernel::Run() {
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto input_ptr = reinterpret_cast<float *>(input_tensor->data_c());
  if (conv_param_->input_channel_ % C4NUM != 0) {
    auto ret = MallocPackedInput();
    if (ret != 0) {
      MS_LOG(ERROR) << "Convolution depthwise fp32 indirect buffer MallocPackedInput failed.";
      return RET_ERROR;
    }
    PackNHWCToNHWC4Fp32(input_ptr, packed_input_, conv_param_->input_batch_,
                        conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
  } else {
    packed_input_ = input_ptr;
  }

  auto output_tensor = out_tensors_.at(kOutputIndex);
  output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());

  ConvDwInitIndirection(indirect_buffer_, packed_input_, zero_ptr_, conv_param_, step_h, step_w);

  auto ret = ParallelLaunch(this->context_->thread_pool_, ConvDwIndirectRun, this, conv_param_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvDwIndirectRun error: error_code[" << ret << "]";
    return RET_ERROR;
  }
  if (conv_param_->input_channel_ % C4NUM != 0) {
    context_->allocator->Free(packed_input_);
  }
  return RET_OK;
}
}  // namespace mindspore::kernel
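Run() only repacks the input when input_channel_ is not a multiple of C4NUM, since the kernels consume channels in groups of four. The sketch below is my own illustration of that packing idea (round the channel count up to a multiple of four and zero the padding lanes), not the actual PackNHWCToNHWC4Fp32 implementation:

// Illustrative sketch of NHWC -> NHWC4 packing: copy the real channels for every spatial
// position and leave the padding lanes zeroed so they contribute nothing to the convolution.
#include <cstring>
#include <vector>

std::vector<float> PackChannelsToMultipleOf4(const float *src, int plane, int channel) {
  int c4 = (channel + 3) / 4 * 4;                                  // channels rounded up to a multiple of 4
  std::vector<float> dst(static_cast<size_t>(plane) * c4, 0.0f);   // padding lanes stay zero
  for (int p = 0; p < plane; ++p) {
    std::memcpy(dst.data() + static_cast<size_t>(p) * c4, src + static_cast<size_t>(p) * channel,
                channel * sizeof(float));
  }
  return dst;
}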
@@ -0,0 +1,54 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_

#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/fp32/conv_depthwise_fp32.h"

namespace mindspore::kernel {
class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel {
 public:
  ConvolutionDepthwiseIndirectCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                        const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                        const mindspore::lite::PrimitiveC *primitive)
      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~ConvolutionDepthwiseIndirectCPUKernel() override;

  int Init() override;
  int ReSize() override;
  int Run() override;

  int InitWeightBias();
  int Execute(int task_id);

 private:
  int MallocIndirectBuffer();
  int MallocPackedInput();
  int step_w = 0;
  int step_h = 0;
  float **indirect_buffer_ = nullptr;
  float *zero_ptr_ = nullptr;
  float *packed_weight_ = nullptr;
  float *output_ptr_ = nullptr;
  float *packed_input_ = nullptr;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_