!7433 [MSLITE] fp16 deconv winograd

Merge pull request !7433 from ling/sr
pull/7433/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 3763e201b5

@ -40,6 +40,7 @@ typedef enum ErrorCodeFp32OpEnum {
typedef enum ErrorCodeFp16OpEnum {
NNACL_ERRCODE_OP_FP16_START = 20000,
NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR,
NNACL_ERRCODE_OP_FP16_END = 29999
} ErrorCodeFp16OpEnum;

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/common_func_fp16.h"
void PostConvFuncCommFp16(float16_t *out_ptr, const float16_t *src_ptr_, const float16_t *bias_ptr,
size_t output_channel, size_t plane_size, size_t oc_stride, size_t hw_stride,
ActType act_type, int size) {
if (size == 0) {
return;
}
for (int oc = 0; oc < output_channel; oc++) {
int oc_div = oc / size, oc_mod = oc % size;
for (int hw = 0; hw < plane_size; hw++) {
int src_index = oc_div * size * hw_stride + hw * size + oc_mod;
int dst_index = hw * oc_stride + oc;
float16_t value = src_ptr_[src_index];
if (bias_ptr != NULL) {
value = value + bias_ptr[oc];
}
value = (act_type == ActType_Relu || act_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value);
value = (act_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value);
out_ptr[dst_index] = value;
}
}
return;
}
void PostConvFuncFp16C8(const float16_t *c8_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane,
size_t oc_stride, ActType act_type) {
size_t oc8mod = oc % C8NUM;
size_t oc8div = oc - oc8mod;
size_t stride_size = oc_stride * sizeof(float16_t);
PostFuncBiasReluC8Fp16(nhwc_out, c8_out, bias, oc8div, oc8mod, plane, stride_size, act_type);
return;
}
void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane,
size_t plane_stride, ActType act_type) {
PostConvFuncCommFp16(nhwc_out, c4_out, bias, oc, plane, oc, plane_stride, act_type, C4NUM);
}

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_COMMON_FUNC_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_COMMON_FUNC_FP16_H_
#include <arm_neon.h>
#include "nnacl/op_base.h"
#ifdef __cplusplus
extern "C" {
#endif
/* deconv common */
void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr,
size_t output_channel, size_t plane_size, size_t stride, ActType act_type);
void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod,
size_t plane_size, size_t stride, size_t relu_type);
/* deconv winograd */
void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const float16_t *bias, size_t output_channel,
size_t plane_size, size_t plane_stride, ActType act_type);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP16_COMMON_FUNC_FP16_H_

@ -13,41 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/deconv_fp16.h"
void PostConvFuncCommFp16(float16_t *out_ptr, const float16_t *src_ptr_, const float16_t *bias_ptr,
size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6,
int size) {
if (size == 0) {
return;
}
for (int oc = 0; oc < output_channel; oc++) {
int oc_div = oc / size, oc_mod = oc % size;
for (int hw = 0; hw < plane_size; hw++) {
int src_index = oc_div * size * plane_size + hw * size + oc_mod;
int dst_index = hw * stride + oc;
float16_t value = src_ptr_[src_index];
if (bias_ptr != NULL) {
value = value + bias_ptr[oc];
}
value = (is_relu || is_relu6) ? (MSMAX(0.f, value)) : (value);
value = (is_relu6) ? (MSMIN(6.f, value)) : (value);
out_ptr[dst_index] = value;
}
}
return;
}
void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr,
size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6) {
size_t oc8mod = output_channel % C8NUM;
size_t oc8div = output_channel - oc8mod;
size_t stride_size = stride * sizeof(float16_t);
size_t relu_type = is_relu ? 1 : 0;
relu_type = is_relu6 ? 3 : relu_type;
PostFuncBiasReluC8Fp16(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type);
return;
}
#include "nnacl/fp16/deconv_fp16.h"
int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel,
ConvParameter *conv_param) {
@ -112,7 +79,6 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
} /*ih*/
} /*oc8*/
PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_,
conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6);
PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->act_type_);
return NNACL_OK;
}

@ -13,27 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_DECONV_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_DECONV_FP16_H_
#include <string.h>
#include <arm_neon.h>
#include <string.h>
#include "nnacl/conv_parameter.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/fp16/matmul_fp16.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp16/common_func_fp16.h"
#ifdef __cplusplus
extern "C" {
#endif
int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel,
ConvParameter *conv_param);
void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr,
size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6);
void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod,
size_t plane_size, size_t stride, size_t relu_type);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load Diff

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_DECONV_WINOGRAD_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_DECONV_WINOGRAD_FP16_H_
#include "nnacl/fp16/winograd_transform_fp16.h"
#ifdef __cplusplus
extern "C" {
#endif
int PackDeConvWgDataFp16(float16_t *nhwc_weight, DeConvComputeUnit *unit, ConvParameter *conv_param,
DeConvParam *deconv_param);
void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_out, int start_index, int calculate_count,
ConvParameter *conv_param, DeConvParam *deconv_param, int task_id);
void DeconvWgPostFp16(float16_t *tile_out, float16_t *nc4hw4_output, ConvParameter *conv_param,
DeConvParam *deconv_param, int calculate_count, int tile_index);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP16_DECONV_WINOGRAD_FP16_H_

@ -81,3 +81,51 @@ void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matri
}
}
}
void WinogradMatrixProductLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length) {
int unitStep = 4 * length;
for (int y = 0; y < h; ++y) {
float16_t *dstY = M + y * w * unitStep;
for (int x = 0; x < w; ++x) {
float16_t *dstX = dstY + x * unitStep;
const float16_t *srcX = S + x * unitStep;
memset(dstX, 0, unitStep * sizeof(float16_t));
for (int i = 0; i < k; ++i) {
float16_t b = B[i * h + y];
const float16_t *srcY = srcX + i * w * unitStep;
if (0.0f == b) {
continue;
}
for (int j = 0; j < unitStep; ++j) {
dstX[j] += srcY[j] * b;
}
}
}
}
}
// M = S * B , M = w*h * l, S = k*h * l, B = w*k
void WinogradMatrixProductRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length) {
int unitStep = 4 * length;
for (int y = 0; y < h; ++y) {
float16_t *dstY = M + y * w * unitStep;
const float16_t *srcY = S + y * k * unitStep;
for (int x = 0; x < w; ++x) {
float16_t *dstX = dstY + x * unitStep;
memset(dstX, 0, unitStep * sizeof(float16_t));
for (int i = 0; i < k; ++i) {
const float16_t *srcX = srcY + i * unitStep;
float16_t b = B[i * h + x];
if (0.0f == b) {
continue;
}
for (int j = 0; j < unitStep; ++j) {
dstX[j] += srcX[j] * b;
}
}
}
}
}

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_NNACL_FP16_MATRIX_FP16_H_
#include <arm_neon.h>
#include <string.h>
#ifdef __cplusplus
extern "C" {
@ -28,6 +29,13 @@ void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matri
const float16_t *bias, int m, int k, int n);
void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k,
int n, int in_channel);
void WinogradMatrixProductLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length);
void WinogradMatrixProductRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length);
#ifdef __cplusplus
}
#endif

@ -712,3 +712,102 @@ void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_d
out_tile_index++;
}
}
int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, float *matrix_g,
float *matrix_gt, int oc_block, int input_unit, int kernel_unit, int filter_channel,
int filter_batch, bool pack) {
// original weight format : ohwi
int oc_block_num = UP_DIV(filter_batch, oc_block);
int block_stride = filter_channel * oc_block;
int block_num_stride = block_stride * oc_block_num;
float16_t *matrix_gt_data_fp16 = (float16_t *)(malloc(input_unit * kernel_unit * sizeof(float16_t)));
if (matrix_gt_data_fp16 == NULL) {
return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR;
}
Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit * kernel_unit);
// trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T
// separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T
float16_t *tmp_data = (float16_t *)(malloc(filter_channel * input_unit * kernel_unit * sizeof(float16_t)));
if (tmp_data == NULL) {
free(matrix_gt_data_fp16);
return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR;
}
float16_t *trans_out_data = (float16_t *)(malloc(filter_channel * input_unit * input_unit * sizeof(float16_t)));
if (trans_out_data == NULL) {
free(tmp_data);
free(matrix_gt_data_fp16);
return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR;
}
#ifndef ENABLE_ARM64
float16_t *tmp_data1 = (float16_t *)(malloc(filter_channel * input_unit * kernel_unit * sizeof(float16_t)));
if (tmp_data1 == NULL) {
free(tmp_data);
free(matrix_gt_data_fp16);
free(trans_out_data);
return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR;
}
float16_t *trans_out_data1 = (float16_t *)(malloc(filter_channel * input_unit * input_unit * sizeof(float16_t)));
if (trans_out_data1 == NULL) {
free(tmp_data);
free(tmp_data1);
free(matrix_gt_data_fp16);
free(trans_out_data);
return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR;
}
#endif
int input_oz_offset = kernel_unit * kernel_unit * filter_channel;
for (int i = 0; i < filter_batch; i++) {
int out_c_block = i / oc_block;
int out_c_res = i % oc_block;
int output_oz_offset = out_c_block * block_stride + out_c_res;
#ifndef ENABLE_ARM64
// tmp_data = g * GT
MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit,
kernel_unit, input_unit, filter_channel);
// tmp_data1 = (tmp_data)T
PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit, input_unit, filter_channel);
// trans_out_data1 = tmp * GT
MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit, kernel_unit, input_unit,
filter_channel);
// trans_out_data = (trans_out_data1)T
PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit, input_unit, filter_channel);
#else
// tmp = (g * GT)T
MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit,
kernel_unit, input_unit, filter_channel);
// trans = (tmp * GT)T
MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit,
filter_channel);
#endif
if (pack) {
int in_offset = 0;
for (int j = 0; j < input_unit; ++j) {
for (int k = 0; k < input_unit; ++k) {
for (int c = 0; c < filter_channel; ++c) {
*(winograd_data + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c];
}
in_offset += filter_channel;
output_oz_offset += block_num_stride;
}
}
} else {
memcpy(winograd_data + i * filter_channel * input_unit * input_unit, trans_out_data,
filter_channel * input_unit * input_unit * sizeof(float16_t));
}
}
#ifndef ENABLE_ARM64
free(tmp_data1);
free(trans_out_data1);
#endif
free(tmp_data);
free(trans_out_data);
free(matrix_gt_data_fp16);
return NNACL_OK;
}

@ -19,9 +19,10 @@
#include <arm_neon.h>
#include <string.h>
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/conv_fp16.h"
#include "nnacl/fp16/winograd_utils_fp16.h"
#include "nnacl/fp16/matrix_fp16.h"
#ifdef __cplusplus
extern "C" {
@ -49,6 +50,12 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param,
OutputTransFp16Func func);
// fp16 winograd weight trans
int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, float *matrix_g,
float *matrix_gt, int oc_block, int input_unit, int kernel_unit, int filter_channel,
int filter_batch, bool pack);
#ifdef __cplusplus
}
#endif

@ -15,23 +15,10 @@
*/
#include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h"
#include "nnacl/fp16/matrix_fp16.h"
#include "nnacl/fp16/conv_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/fp16/winograd_transform_fp16.h"
#include "nnacl/fp16/winograd_utils_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g,
@ -40,104 +27,9 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
MS_LOG(ERROR) << "Divide by zero";
return RET_ERROR;
}
// original weight format : ohwi
auto channel_in = conv_param_->input_channel_;
auto channel_out = conv_param_->output_channel_;
int oc_block_num = UP_DIV(channel_out, oc_block);
int block_stride = channel_in * oc_block;
int block_num_stride = block_stride * oc_block_num;
auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
if (matrix_gt_data_fp16 == nullptr) {
MS_LOG(ERROR) << "malloc matrix_gt_data_fp16 failed.";
return RET_ERROR;
}
Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit_ * kernel_unit_);
// trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T
// separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T
auto tmp_data = reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float16_t)));
if (tmp_data == nullptr) {
free(matrix_gt_data_fp16);
MS_LOG(ERROR) << "malloc tmp_data failed.";
return RET_ERROR;
}
auto trans_out_data =
reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t)));
if (trans_out_data == nullptr) {
free(tmp_data);
free(matrix_gt_data_fp16);
MS_LOG(ERROR) << "malloc trans_out_data failed.";
return RET_ERROR;
}
#ifndef ENABLE_ARM64
auto tmp_data1 = reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float16_t)));
if (tmp_data1 == nullptr) {
free(tmp_data);
free(matrix_gt_data_fp16);
free(trans_out_data);
MS_LOG(ERROR) << "malloc tmp_data1 failed.";
return RET_ERROR;
}
auto trans_out_data1 =
reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t)));
if (trans_out_data1 == nullptr) {
free(tmp_data);
free(tmp_data1);
free(matrix_gt_data_fp16);
free(trans_out_data);
MS_LOG(ERROR) << "malloc trans_out_data1 failed.";
return RET_ERROR;
}
#endif
int input_oz_offset = kernel_unit_ * kernel_unit_ * channel_in;
for (int i = 0; i < channel_out; i++) {
int out_c_block = i / oc_block;
int out_c_res = i % oc_block;
int output_oz_offset = out_c_block * block_stride + out_c_res;
#ifndef ENABLE_ARM64
// tmp_data = g * GT
MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_,
kernel_unit_, input_unit_, channel_in);
// tmp_data1 = (tmp_data)T
PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit_, input_unit_, channel_in);
// trans_out_data1 = tmp * GT
MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit_, kernel_unit_, input_unit_,
channel_in);
// trans_out_data = (trans_out_data1)T
PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit_, input_unit_, channel_in);
#else
// tmp = (g * GT)T
MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_,
kernel_unit_, input_unit_, channel_in);
// trans = (tmp * GT)T
MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_,
channel_in);
#endif
int in_offset = 0;
for (int j = 0; j < input_unit_; ++j) {
for (int k = 0; k < input_unit_; ++k) {
for (int c = 0; c < channel_in; ++c) {
*(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c];
}
in_offset += channel_in;
output_oz_offset += block_num_stride;
}
}
}
#ifndef ENABLE_ARM64
free(tmp_data1);
free(trans_out_data1);
#endif
free(tmp_data);
free(trans_out_data);
free(matrix_gt_data_fp16);
return RET_OK;
return WinogradWeightTransformFp16(weight_data, trans_weight_, matrix_g, matrix_gt, oc_block, input_unit_,
kernel_unit_, conv_param_->input_channel_, conv_param_->output_channel_, true);
}
int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h"
#include "src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
@ -64,7 +65,7 @@ int DeConvolutionFp16CPUKernel::InitWeightBias() {
memset(bias_data_, 0, UP_ROUND(output_channel, C4NUM) * sizeof(float16_t));
if (in_tensors_.size() == 3) {
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->MutableData()),
reinterpret_cast<float16_t *>(bias_data_), conv_param_->output_channel_);
reinterpret_cast<float16_t *>(bias_data_), output_channel);
}
size_t weight_pack_size = input_channel * kernel_w * kernel_h * UP_ROUND(output_channel, C8NUM) * sizeof(float16_t);
@ -158,9 +159,10 @@ int DeConvolutionFp16CPUKernel::DoDeconv(int task_id) {
MatMulFp16(pack_input_, execute_weight_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buf, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_, oc * C8NUM * kernel_plane_, 0,
OutType_C8);
DeConvPostFp16(tmp_buf, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float16_t *>(bias_data_) + task_id * thread_stride_ * C8NUM,
execute_output_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
batch_output_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
return RET_OK;
}
@ -191,7 +193,10 @@ int DeConvolutionFp16CPUKernel::Run() {
}
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
RowMajor2Col16MajorFp16Opt(execute_input_, pack_input_, input_plane_, conv_param_->input_channel_);
batch_input_ = execute_input_ + batch_index * conv_param_->input_channel_ * input_plane_;
batch_output_ = execute_output_ + batch_index * conv_param_->output_channel_ * output_plane_;
RowMajor2Col16MajorFp16Opt(batch_input_, pack_input_, input_plane_, conv_param_->input_channel_);
error_code = ParallelLaunch(this->context_->thread_pool_, DeConvFp16Run, this, thread_count_);
if (error_code != RET_OK) {
@ -229,7 +234,16 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
weight_tensor->SetData(dequant_weight);
}
auto kernel = new (std::nothrow) DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel::LiteKernel *kernel;
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) &&
(conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1)) {
/* DeConvWinogradFp16CPUKernel */
kernel = new (std::nothrow) kernel::DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) kernel::DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {

@ -17,17 +17,11 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_H_
#include <float.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "schema/model_generated.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "nnacl/fp16/deconv_fp16.h"
#include "nnacl/fp16/matmul_fp16.h"
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
namespace mindspore::kernel {
class DeConvolutionFp16CPUKernel : public ConvolutionBaseFP16CPUKernel {
@ -65,6 +59,8 @@ class DeConvolutionFp16CPUKernel : public ConvolutionBaseFP16CPUKernel {
float16_t *pack_input_;
float16_t *pack_output_;
float16_t *tmp_buffer_;
float16_t *batch_input_;
float16_t *batch_output_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_H_

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_WINOGRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_WINOGRAD_H_
#include <vector>
#include "include/errorcode.h"
#include "nnacl/fp16/common_func_fp16.h"
#include "nnacl/fp16/deconv_winograd_fp16.h"
#include "nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
namespace mindspore::kernel {
class DeConvWinogradFp16CPUKernel : public ConvolutionBaseFP16CPUKernel {
public:
DeConvWinogradFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {
deconv_param_ = new DeConvParam();
for (auto &wg : deconv_param_->a_buffer_) {
wg.buf_init_ = false;
}
}
~DeConvWinogradFp16CPUKernel() override;
int Init() override;
int Run() override;
int ReSize() override;
public:
int DoDeconv(int task_id);
int DeDeconvPost(int task_id);
private:
int InitComputeParam();
int InitDataParam();
int InitParameter();
void FreeDeconvParam();
void FreeResizeBuf();
private:
DeConvParam *deconv_param_;
std::mutex lock_;
float16_t *nhwc_input_ = nullptr;
float16_t *nhwc_output_ = nullptr;
float16_t *nc4hw4_output_ = nullptr;
float16_t *tile_input_ = nullptr;
float16_t *tile_output_ = nullptr;
int thread_num_hw_;
int thread_stride_hw_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_WINOGRAD_H_

@ -17,16 +17,11 @@
#include "src/runtime/kernel/arm/fp32/deconvolution_winograd.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DeConv2D;
using mindspore::schema::Format::Format_NHWC;
namespace mindspore::kernel {
DeConvolutionWinogradCPUKernel::~DeConvolutionWinogradCPUKernel() {
FreeResizeBuf();
FreeDeconvParam();
@ -352,10 +347,7 @@ int DeConvolutionWinogradCPUKernel::Run() {
nhwc_output_ = src_out + batch_index * deconv_param_->output_plane_ * conv_param_->output_channel_;
::memset(nc4hw4_output_, 0, deconv_param_->output_plane_ * deconv_param_->oc_div4_ * C4NUM * sizeof(float));
for (int i = 0; i < deconv_param_->thread_num_; i++) {
DoDeconv(i);
}
// ParallelLaunch(this->context_->thread_pool_, DeConvWgFp32Run, this, deconv_param_->thread_num_);
ParallelLaunch(this->context_->thread_pool_, DeConvWgFp32Run, this, deconv_param_->thread_num_);
/*post bias activate and nhwc */
ParallelLaunch(this->context_->thread_pool_, DeConvWgPostFp32Run, this, thread_num_hw_);

Loading…
Cancel
Save