pull/5423/head
parent
e6112ed1ba
commit
8d06c2b8be
@ -0,0 +1,169 @@
|
|||||||
|
#ifdef __aarch64__
|
||||||
|
|
||||||
|
.text
|
||||||
|
.align 5
|
||||||
|
.global ConvDwInt8PostAlign4
|
||||||
|
#ifndef __APPLE__
|
||||||
|
.type ConvDwInt8PostAlign4, %function
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
|
||||||
|
// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
|
||||||
|
// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier,
|
||||||
|
// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max
|
||||||
|
|
||||||
|
ConvDwInt8PostAlign4:
|
||||||
|
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
|
||||||
|
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
|
||||||
|
// x19 ~ x29 should be also preserved
|
||||||
|
// whereas our coding style do not permit such amount of parameters
|
||||||
|
ldr x8, [sp]
|
||||||
|
|
||||||
|
dup v26.4s, w5
|
||||||
|
dup v27.4s, w4
|
||||||
|
dup v28.4s, w6
|
||||||
|
|
||||||
|
dup v29.4s, w3
|
||||||
|
dup v30.4s, w7
|
||||||
|
dup v31.4s, w8
|
||||||
|
|
||||||
|
cmp x2, 16
|
||||||
|
blt LoopDepth8
|
||||||
|
|
||||||
|
LoopDepth16:
|
||||||
|
ld1 {v0.4s}, [x1], #16
|
||||||
|
ld1 {v1.4s}, [x1], #16
|
||||||
|
ld1 {v2.4s}, [x1], #16
|
||||||
|
ld1 {v3.4s}, [x1], #16
|
||||||
|
|
||||||
|
sqshl v0.4s, v0.4s, v26.4s
|
||||||
|
sqshl v1.4s, v1.4s, v26.4s
|
||||||
|
sqshl v2.4s, v2.4s, v26.4s
|
||||||
|
sqshl v3.4s, v3.4s, v26.4s
|
||||||
|
|
||||||
|
sqrdmulh v0.4s, v0.4s, v27.4s
|
||||||
|
sqrdmulh v1.4s, v1.4s, v27.4s
|
||||||
|
sqrdmulh v2.4s, v2.4s, v27.4s
|
||||||
|
sqrdmulh v3.4s, v3.4s, v27.4s
|
||||||
|
|
||||||
|
and v16.16b, v28.16b, v0.16b
|
||||||
|
sshr v16.4s, v16.4s, #31
|
||||||
|
sqadd v0.4s, v0.4s, v16.4s
|
||||||
|
srshl v0.4s, v0.4s, v28.4s
|
||||||
|
and v17.16b, v28.16b, v1.16b
|
||||||
|
sshr v17.4s, v17.4s, #31
|
||||||
|
sqadd v1.4s, v1.4s, v17.4s
|
||||||
|
srshl v1.4s, v1.4s, v28.4s
|
||||||
|
and v18.16b, v28.16b, v2.16b
|
||||||
|
sshr v18.4s, v18.4s, #31
|
||||||
|
sqadd v2.4s, v2.4s, v18.4s
|
||||||
|
srshl v2.4s, v2.4s, v28.4s
|
||||||
|
and v19.16b, v28.16b, v3.16b
|
||||||
|
sshr v19.4s, v19.4s, #31
|
||||||
|
sqadd v3.4s, v3.4s, v19.4s
|
||||||
|
srshl v3.4s, v3.4s, v28.4s
|
||||||
|
|
||||||
|
add v0.4s, v0.4s, v29.4s
|
||||||
|
add v1.4s, v1.4s, v29.4s
|
||||||
|
add v2.4s, v2.4s, v29.4s
|
||||||
|
add v3.4s, v3.4s, v29.4s
|
||||||
|
|
||||||
|
smax v0.4s, v0.4s, v30.4s
|
||||||
|
smax v1.4s, v1.4s, v30.4s
|
||||||
|
smax v2.4s, v2.4s, v30.4s
|
||||||
|
smax v3.4s, v3.4s, v30.4s
|
||||||
|
|
||||||
|
smin v0.4s, v0.4s, v31.4s
|
||||||
|
smin v1.4s, v1.4s, v31.4s
|
||||||
|
smin v2.4s, v2.4s, v31.4s
|
||||||
|
smin v3.4s, v3.4s, v31.4s
|
||||||
|
|
||||||
|
sqxtn v0.4h, v0.4s
|
||||||
|
sqxtn v1.4h, v1.4s
|
||||||
|
sqxtn v2.4h, v2.4s
|
||||||
|
sqxtn v3.4h, v3.4s
|
||||||
|
|
||||||
|
sqxtn v0.8b, v0.8h
|
||||||
|
sqxtn v1.8b, v1.8h
|
||||||
|
sqxtn v2.8b, v2.8h
|
||||||
|
sqxtn v3.8b, v3.8h
|
||||||
|
|
||||||
|
st1 {v0.s}[0], [x0], #4
|
||||||
|
st1 {v1.s}[0], [x0], #4
|
||||||
|
st1 {v2.s}[0], [x0], #4
|
||||||
|
st1 {v3.s}[0], [x0], #4
|
||||||
|
|
||||||
|
sub x2, x2, #16
|
||||||
|
cmp x2, #16
|
||||||
|
bge LoopDepth16
|
||||||
|
|
||||||
|
LoopDepth8:
|
||||||
|
cmp x2, #8
|
||||||
|
blt LoopDepth4
|
||||||
|
ld1 {v0.4s}, [x1], #16
|
||||||
|
ld1 {v1.4s}, [x1], #16
|
||||||
|
|
||||||
|
sqshl v0.4s, v0.4s, v26.4s
|
||||||
|
sqshl v1.4s, v1.4s, v26.4s
|
||||||
|
|
||||||
|
sqrdmulh v0.4s, v0.4s, v27.4s
|
||||||
|
sqrdmulh v1.4s, v1.4s, v27.4s
|
||||||
|
|
||||||
|
and v16.16b, v28.16b, v0.16b
|
||||||
|
sshr v16.4s, v16.4s, #31
|
||||||
|
sqadd v0.4s, v0.4s, v16.4s
|
||||||
|
srshl v0.4s, v0.4s, v28.4s
|
||||||
|
and v17.16b, v28.16b, v1.16b
|
||||||
|
sshr v17.4s, v17.4s, #31
|
||||||
|
sqadd v1.4s, v1.4s, v17.4s
|
||||||
|
srshl v1.4s, v1.4s, v28.4s
|
||||||
|
|
||||||
|
add v0.4s, v0.4s, v29.4s
|
||||||
|
add v1.4s, v1.4s, v29.4s
|
||||||
|
|
||||||
|
smax v0.4s, v0.4s, v30.4s
|
||||||
|
smax v1.4s, v1.4s, v30.4s
|
||||||
|
|
||||||
|
smin v0.4s, v0.4s, v31.4s
|
||||||
|
smin v1.4s, v1.4s, v31.4s
|
||||||
|
|
||||||
|
sqxtn v0.4h, v0.4s
|
||||||
|
sqxtn v1.4h, v1.4s
|
||||||
|
|
||||||
|
sqxtn v0.8b, v0.8h
|
||||||
|
sqxtn v1.8b, v1.8h
|
||||||
|
|
||||||
|
st1 {v0.s}[0], [x0], #4
|
||||||
|
st1 {v1.s}[0], [x0], #4
|
||||||
|
|
||||||
|
sub x2, x2, #8
|
||||||
|
cmp x2, #8
|
||||||
|
bge LoopDepth8
|
||||||
|
|
||||||
|
LoopDepth4:
|
||||||
|
cmp x2, #4
|
||||||
|
blt End
|
||||||
|
ld1 {v0.4s}, [x1], #16
|
||||||
|
|
||||||
|
sqshl v0.4s, v0.4s, v26.4s
|
||||||
|
sqrdmulh v0.4s, v0.4s, v27.4s
|
||||||
|
|
||||||
|
and v16.16b, v28.16b, v0.16b
|
||||||
|
sshr v16.4s, v16.4s, #31
|
||||||
|
sqadd v0.4s, v0.4s, v16.4s
|
||||||
|
srshl v0.4s, v0.4s, v28.4s
|
||||||
|
|
||||||
|
add v0.4s, v0.4s, v29.4s
|
||||||
|
smax v0.4s, v0.4s, v30.4s
|
||||||
|
smin v0.4s, v0.4s, v31.4s
|
||||||
|
|
||||||
|
sqxtn v0.4h, v0.4s
|
||||||
|
sqxtn v0.8b, v0.8h
|
||||||
|
|
||||||
|
st1 {v0.s}[0], [x0], #4
|
||||||
|
|
||||||
|
sub x2, x2, #4
|
||||||
|
bge LoopDepth4
|
||||||
|
End:
|
||||||
|
ret
|
||||||
|
#endif
|
@ -0,0 +1,122 @@
|
|||||||
|
#ifdef __aarch64__
|
||||||
|
|
||||||
|
.text
|
||||||
|
.align 5
|
||||||
|
.global ConvDwInt8Row
|
||||||
|
#ifndef __APPLE__
|
||||||
|
.type ConvDwInt8Row, %function
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
|
||||||
|
// int output_channel, int input_step, int8_t input_zp)
|
||||||
|
// x0: output_ptr, x1: input_ptr, x2: weight_ptr, x3: num_pixels,
|
||||||
|
// x4: output_channel, x5: input_step, x6: input_zp
|
||||||
|
//
|
||||||
|
ConvDwInt8Row:
|
||||||
|
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
|
||||||
|
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
|
||||||
|
// x19 ~ x29 should be also preserved
|
||||||
|
// whereas our coding style do not permit such amount of parameters
|
||||||
|
cmp x3, #0
|
||||||
|
beq End
|
||||||
|
|
||||||
|
mov x10, x0
|
||||||
|
|
||||||
|
dup v31.8b, w6
|
||||||
|
|
||||||
|
LoopOutPixel:
|
||||||
|
mov x7, x1
|
||||||
|
mov x8, x2
|
||||||
|
mov x9, x4
|
||||||
|
|
||||||
|
LoopDepth16In:
|
||||||
|
cmp x9, #16
|
||||||
|
blt L8
|
||||||
|
sub x9, x9, #16
|
||||||
|
|
||||||
|
ld1 {v0.8b, v1.8b}, [x7], #16
|
||||||
|
ld1 {v2.8h, v3.8h}, [x8], #32
|
||||||
|
ld1 {v16.4s, v17.4s}, [x0], #32
|
||||||
|
|
||||||
|
ssubl v20.8h, v0.8b, v31.8b
|
||||||
|
smlal v16.4s, v20.4h, v2.4h
|
||||||
|
smlal2 v17.4s, v20.8h, v2.8h
|
||||||
|
|
||||||
|
|
||||||
|
cmp x9, #16
|
||||||
|
blt LoopDepth16Out
|
||||||
|
LoopDepth16:
|
||||||
|
|
||||||
|
st1 {v16.4s, v17.4s}, [x10], #32
|
||||||
|
ld1 {v18.4s, v19.4s}, [x0], #32
|
||||||
|
ssubl v21.8h, v1.8b, v31.8b
|
||||||
|
smlal v18.4s, v21.4h, v3.4h
|
||||||
|
smlal2 v19.4s, v21.8h, v3.8h
|
||||||
|
st1 {v18.4s, v19.4s}, [x10], #32
|
||||||
|
|
||||||
|
ld1 {v0.8b, v1.8b}, [x7], #16
|
||||||
|
ld1 {v2.8h, v3.8h}, [x8], #32
|
||||||
|
ld1 {v16.4s, v17.4s}, [x0], #32
|
||||||
|
|
||||||
|
ssubl v20.8h, v0.8b, v31.8b
|
||||||
|
smlal v16.4s, v20.4h, v2.4h
|
||||||
|
smlal2 v17.4s, v20.8h, v2.8h
|
||||||
|
|
||||||
|
sub x9, x9, #16
|
||||||
|
cmp x9, #16
|
||||||
|
bge LoopDepth16
|
||||||
|
|
||||||
|
LoopDepth16Out:
|
||||||
|
|
||||||
|
st1 {v16.4s, v17.4s}, [x10], #32
|
||||||
|
ld1 {v18.4s, v19.4s}, [x0], #32
|
||||||
|
ssubl v21.8h, v1.8b, v31.8b
|
||||||
|
smlal v18.4s, v21.4h, v3.4h
|
||||||
|
smlal2 v19.4s, v21.8h, v3.8h
|
||||||
|
st1 {v18.4s, v19.4s}, [x10], #32
|
||||||
|
|
||||||
|
L8:
|
||||||
|
cmp x9, #8
|
||||||
|
blt L0
|
||||||
|
|
||||||
|
LoopDepth8:
|
||||||
|
ld1 {v0.8b}, [x7], #8
|
||||||
|
ld1 {v2.8h}, [x8], #16
|
||||||
|
ld1 {v16.4s, v17.4s}, [x0], #32
|
||||||
|
|
||||||
|
ssubl v20.8h, v0.8b, v31.8b
|
||||||
|
smlal v16.4s, v20.4h, v2.4h
|
||||||
|
smlal2 v17.4s, v20.8h, v2.8h
|
||||||
|
st1 {v16.4s, v17.4s}, [x10], #32
|
||||||
|
|
||||||
|
sub x9, x9, #8
|
||||||
|
cmp x9, #8
|
||||||
|
bge LoopDepth8
|
||||||
|
|
||||||
|
L0:
|
||||||
|
cmp x9, #0
|
||||||
|
beq Loop16LineEnd
|
||||||
|
|
||||||
|
LoopDepth0:
|
||||||
|
ldrsb w14, [x7], #1
|
||||||
|
ldrsh w15, [x8], #2
|
||||||
|
ldr w16, [x0], #4
|
||||||
|
add w14, w14, w6
|
||||||
|
|
||||||
|
sxth w14, w14
|
||||||
|
madd w14, w14, w15, w16
|
||||||
|
str w14, [x10], #4
|
||||||
|
|
||||||
|
subs x9, x9, #1
|
||||||
|
bne LoopDepth0
|
||||||
|
|
||||||
|
Loop16LineEnd:
|
||||||
|
|
||||||
|
subs x3, x3, #1
|
||||||
|
add x1, x1, x5
|
||||||
|
bne LoopOutPixel
|
||||||
|
|
||||||
|
End:
|
||||||
|
ret
|
||||||
|
|
||||||
|
#endif
|
@ -0,0 +1,182 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.h"
|
||||||
|
#include "schema/model_generated.h"
|
||||||
|
#include "src/kernel_registry.h"
|
||||||
|
#include "include/errorcode.h"
|
||||||
|
#include "nnacl/int8/conv_depthwise_int8.h"
|
||||||
|
#include "src/runtime/runtime_api.h"
|
||||||
|
|
||||||
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||||
|
using mindspore::lite::KernelRegistrar;
|
||||||
|
using mindspore::lite::RET_ERROR;
|
||||||
|
using mindspore::lite::RET_OK;
|
||||||
|
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
|
||||||
|
|
||||||
|
namespace mindspore::kernel {
|
||||||
|
ConvolutionDepthwiseSWInt8CPUKernel::~ConvolutionDepthwiseSWInt8CPUKernel() {
|
||||||
|
if (sliding != nullptr) {
|
||||||
|
delete sliding;
|
||||||
|
sliding = nullptr;
|
||||||
|
}
|
||||||
|
if (packed_weight_ != nullptr) {
|
||||||
|
free(packed_weight_);
|
||||||
|
packed_weight_ = nullptr;
|
||||||
|
}
|
||||||
|
FreeQuantParam();
|
||||||
|
}
|
||||||
|
|
||||||
|
int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
|
||||||
|
// init weight, int8 -> int16
|
||||||
|
// o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1
|
||||||
|
auto weight_tensor = in_tensors_[kWeightIndex];
|
||||||
|
auto origin_weight = reinterpret_cast<int8_t *>(weight_tensor->Data());
|
||||||
|
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
|
||||||
|
int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
|
||||||
|
packed_weight_ = reinterpret_cast<int16_t *>(malloc(pack_weight_size * sizeof(int16_t)));
|
||||||
|
if (packed_weight_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
|
||||||
|
weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
|
||||||
|
|
||||||
|
bias_data_ = reinterpret_cast<int32_t *>(malloc(C4NUM * OC4 * sizeof(int32_t)));
|
||||||
|
if (bias_data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t));
|
||||||
|
if (in_tensors_.size() == kInputSize2) {
|
||||||
|
auto bias_tensor = in_tensors_.at(kBiasIndex);
|
||||||
|
auto ori_bias = reinterpret_cast<int32_t *>(bias_tensor->Data());
|
||||||
|
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
conv_param_->thread_num_ = MSMIN(thread_count_, OC4);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() {
|
||||||
|
int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM *
|
||||||
|
UP_DIV(conv_param_->input_channel_, 4);
|
||||||
|
packed_input_ = reinterpret_cast<int16_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int16_t)));
|
||||||
|
if (packed_input_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (conv_param_->input_channel_ % C4NUM != 0) {
|
||||||
|
need_align_ = true;
|
||||||
|
int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM *
|
||||||
|
UP_DIV(conv_param_->output_channel_, C4NUM);
|
||||||
|
packed_output_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(pack_output_size * sizeof(int8_t)));
|
||||||
|
if (packed_input_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ConvolutionDepthwiseSWInt8CPUKernel::Init() {
|
||||||
|
sliding = new (std::nothrow) SlidingWindowParam;
|
||||||
|
if (sliding == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new sliding window param.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (!InferShapeDone()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
return ReSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() {
|
||||||
|
ConvolutionBaseCPUKernel::Init();
|
||||||
|
InitSlidingParamConvDw(sliding, conv_param_, C4NUM);
|
||||||
|
|
||||||
|
auto ret = ConvolutionBaseCPUKernel::SetQuantParam();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Set quant param failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
ret = InitWeightBias();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Depthwise int8 InitWeightBias error!";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ConvolutionDepthwiseSWInt8CPUKernel::Execute(int task_id) {
|
||||||
|
ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
|
||||||
|
sliding, task_id);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ConvDwSWInt8Run(void *cdata, int task_id) {
|
||||||
|
auto conv_dw_int8 = reinterpret_cast<ConvolutionDepthwiseSWInt8CPUKernel *>(cdata);
|
||||||
|
auto ret = conv_dw_int8->Execute(task_id);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "ConvolutionDepthwiseSWInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ConvolutionDepthwiseSWInt8CPUKernel::Run() {
|
||||||
|
if (conv_param_->input_channel_ != conv_param_->output_channel_) {
|
||||||
|
MS_LOG(ERROR) << "Only support input channel equals output channel.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
auto ret = Prepare();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Prepare failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = InitBuffer();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Depthwise int8 ReSize error!";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||||
|
auto input_addr = reinterpret_cast<int8_t *>(input_tensor->Data());
|
||||||
|
PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_);
|
||||||
|
|
||||||
|
auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(kOutputIndex)->Data());
|
||||||
|
if (!need_align_) {
|
||||||
|
packed_output_ = output_addr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = ParallelLaunch(THREAD_POOL_DEFAULT, ConvDwSWInt8Run, this, conv_param_->thread_num_);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "ConvDwSWInt8Run error: error_code[" << ret << "]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (need_align_) {
|
||||||
|
PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_,
|
||||||
|
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
|
||||||
|
context_->allocator->Free(packed_output_);
|
||||||
|
}
|
||||||
|
context_->allocator->Free(packed_input_);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore::kernel
|
@ -0,0 +1,51 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_INT8_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_INT8_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "src/lite_kernel.h"
|
||||||
|
#include "src/runtime/kernel/arm/base/convolution_base.h"
|
||||||
|
#include "nnacl/fp32/conv_depthwise.h"
|
||||||
|
|
||||||
|
namespace mindspore::kernel {
|
||||||
|
class ConvolutionDepthwiseSWInt8CPUKernel : public ConvolutionBaseCPUKernel {
|
||||||
|
public:
|
||||||
|
ConvolutionDepthwiseSWInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
|
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
|
||||||
|
const mindspore::lite::PrimitiveC *primitive)
|
||||||
|
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||||
|
~ConvolutionDepthwiseSWInt8CPUKernel() override;
|
||||||
|
|
||||||
|
int Init() override;
|
||||||
|
int ReSize() override;
|
||||||
|
int Run() override;
|
||||||
|
|
||||||
|
int InitWeightBias();
|
||||||
|
int InitBuffer();
|
||||||
|
int Execute(int task_id);
|
||||||
|
|
||||||
|
private:
|
||||||
|
SlidingWindowParam *sliding = nullptr;
|
||||||
|
int16_t *packed_weight_ = nullptr;
|
||||||
|
int16_t *packed_input_ = nullptr;
|
||||||
|
int8_t *packed_output_ = nullptr;
|
||||||
|
bool need_align_ = false;
|
||||||
|
};
|
||||||
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_INT8_H_
|
Loading…
Reference in new issue