!4505 [MS][LITE] optimize arm cpu fp16 op: modify pack function, judge input data type

Merge pull request !4505 from yangruoqi713/lite
pull/4505/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit f19cf80a20

@ -16,6 +16,7 @@
#include "src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@ -177,10 +178,22 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() {
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_addr = reinterpret_cast<float *>(input_tensor->Data());
float16_t *input_addr;
if (input_tensor->data_type() == kNumberTypeFloat32) {
input_addr =
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
if (input_addr == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->Data()), input_addr, input_tensor->ElementsNum());
} else {
input_addr = reinterpret_cast<float16_t *>(input_tensor->Data());
}
// pack input: to nhwc8
PackNHWCFp32ToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_,
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
PackNHWCToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_,
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
ret = LiteBackendParallelLaunch(ConvDwFp16Run, this, conv_param_->thread_num_);
if (ret != RET_OK) {
@ -188,10 +201,13 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() {
return RET_ERROR;
}
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
PackNHWC8Fp16ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
auto output_addr = reinterpret_cast<float16_t *>(out_tensors_.at(kOutputIndex)->Data());
PackNHWC8ToNHWCFp16(packed_output_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
if (input_tensor->data_type() == kNumberTypeFloat32) {
context_->allocator->Free(input_addr);
}
return RET_OK;
}

@ -334,31 +334,57 @@ void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane,
}
void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
int nhwc8_batch_offset = 0;
int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
float16_t *dst_batch = dst + b * plane * c8_channel;
float *src_batch = src + b * plane * channel;
for (int i = 0; i < plane; i++) {
float16_t *dst_plane = dst_batch + i * c8_channel;
float *src_plane = src_batch + i * channel;
for (int c = 0; c < channel; c++) {
(dst + nhwc8_batch_offset + i * c8 * C8NUM)[c] = (float16_t)(src + batch_offset + i * channel)[c];
dst_plane[c] = (float16_t)(src_plane[c]);
}
}
nhwc8_batch_offset += nhwc8_batch_unit_offset;
}
}
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
int nhwc_batch_unit_offset = channel * plane;
int nhwc_batch_offset = 0;
int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
for (int b = 0; b < batch; b++) {
int batch_offset = b * c8 * C8NUM * plane;
float16_t *src_batch = src + b * plane * c8_channel;
float *dst_batch = dst + b * plane * channel;
for (int i = 0; i < plane; i++) {
float16_t *src_plane = src_batch + i * c8_channel;
float *dst_plane = dst_batch + i * channel;
for (int c = 0; c < channel; c++) {
(dst + nhwc_batch_offset + i * channel)[c] = (float)(src + batch_offset + i * c8 * C8NUM)[c];
dst_plane[c] = (float16_t)(src_plane[c]);
}
}
nhwc_batch_offset += nhwc_batch_unit_offset;
}
}
void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel) {
int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
for (int b = 0; b < batch; b++) {
float16_t *dst_batch = dst + b * plane * c8_channel;
float16_t *src_batch = src + b * plane * channel;
for (int i = 0; i < plane; i++) {
float16_t *dst_plane = dst_batch + i * c8_channel;
float16_t *src_plane = src_batch + i * channel;
memcpy(dst_plane, src_batch, channel * sizeof(float16_t));
}
}
}
void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel) {
int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
for (int b = 0; b < batch; b++) {
float16_t *src_batch = src + b * plane * c8_channel;
float16_t *dst_batch = dst + b * plane * channel;
for (int i = 0; i < plane; i++) {
float16_t *src_plane = src_batch + i * c8_channel;
float16_t *dst_plane = dst_batch + i * channel;
memcpy(dst_plane, src_batch, channel * sizeof(float16_t));
}
}
}

@ -58,6 +58,10 @@ void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane,
void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel);
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);
void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);
#ifdef __cplusplus
}
#endif

Loading…
Cancel
Save