parent
c78683a411
commit
ca87533cd7
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,90 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "nnacl/fp32/adder_fp32.h"
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "nnacl/fp32/common_func_fp32.h"
|
||||||
|
#include "nnacl/fp32/matmul_fp32.h"
|
||||||
|
|
||||||
|
/**
 * Plain-C adder kernel. For every output element (r, c) it sums |a - b| over the
 * reduction dimension, negates the sum, then applies the optional bias and activation.
 *
 * @param a        left operand, read in tiles of 12 rows: element (r, d) lives at
 *                 (r / 12) * deep * 12 + d * 12 + (r % 12)
 * @param b        right operand, read in tiles of 4 columns: element (d, c) lives at
 *                 (c / 4) * deep * 4 + d * 4 + (c % 4)
 * @param dst      output buffer; element (r, c) is written at r * stride + c
 * @param bias     optional per-column bias indexed by c; skipped when NULL
 * @param act_type ActType_Relu6 clamps the result to at most 6; any value other than
 *                 ActType_No clamps it to at least 0
 * @param deep     length of the reduction dimension
 * @param row      number of output rows to produce
 * @param col      number of output columns to produce
 * @param stride   distance in elements between consecutive output rows
 */
void Adder12x4(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
               int col, int stride) {
  for (int out_r = 0; out_r < row; ++out_r) {
    // Position of this row inside the 12-row tiling of `a`.
    const int a_tile = out_r / 12;
    const int a_lane = out_r % 12;

    for (int out_c = 0; out_c < col; ++out_c) {
      // Position of this column inside the 4-column tiling of `b`.
      const int b_tile = out_c / 4;
      const int b_lane = out_c % 4;

      float acc = 0;
      for (int k = 0; k < deep; ++k) {
        const size_t a_idx = a_tile * deep * 12 + k * 12 + a_lane;
        const size_t b_idx = b_tile * deep * 4 + k * 4 + b_lane;
        acc += fabsf(a[a_idx] - b[b_idx]);
      }

      // The response is the negated sum of absolute differences.
      float result = -acc;
      if (bias != NULL) {
        result += bias[out_c];
      }
      if (act_type == ActType_Relu6) {
        result = MSMIN(6.0f, result);
      }
      if (act_type != ActType_No) {
        result = MSMAX(0.0f, result);
      }

      const size_t out_idx = out_r * stride + out_c;
      dst[out_idx] = result;
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * Entry point for the adder product of `a` and `b` into `c`. The implementation is
 * chosen at compile time: AdderFloatNeon64 when ENABLE_ARM64 is defined, otherwise the
 * plain-C Adder12x4. All arguments are forwarded unchanged, except that the ARM64 path
 * receives act_type converted to int.
 */
void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
              size_t stride) {
#ifndef ENABLE_ARM64
  // Portable fallback.
  Adder12x4(a, b, c, bias, act_type, deep, row, col, stride);
#else
  // ARM64 kernel takes the activation type as a plain int.
  AdderFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride);
#endif
}
|
||||||
|
|
||||||
|
/**
 * Runs the adder convolution for the output tiles owned by one task.
 *
 * The output plane (output_h_ * output_w_ positions) is cut into tiles of `cal_num`
 * positions. Tiles are handed out round-robin: this task handles tile `task_id`,
 * `task_id + thread_num_`, and so on, for every batch. Each tile is expanded with
 * Im2ColPackUnitFp32, repacked to column-major tiles and passed to AdderOpt.
 *
 * @param input_data      input feature map; batch b starts at
 *                        b * input_channel_ * input_h_ * input_w_
 * @param packed_input    scratch buffer; this task uses deep * cal_num floats starting at
 *                        task_id * deep * cal_num
 * @param packed_weight   packed filter, forwarded to AdderOpt as the right operand
 * @param bias_data       bias forwarded to AdderOpt
 * @param col_major_input scratch buffer with the same per-task slicing as packed_input
 * @param output_data     output; a tile is written starting at
 *                        thread_id * cal_num * out_channel plus the batch offset
 * @param task_id         index of the calling task
 * @param conv_param      convolution geometry, activation type and thread count
 */
void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
               float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) {
  // The tile loop below advances by thread_num_; a non-positive step would never
  // terminate (0) or would walk to negative tile indices (< 0).
  if (conv_param->thread_num_ <= 0) {
    return;
  }

  const int out_channel = conv_param->output_channel_;
  const int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
  const int output_count = conv_param->output_h_ * conv_param->output_w_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
  const int cal_num = C4NUM;
#else
  const int cal_num = C12NUM;
#endif
  const int output_tile_count = UP_DIV(output_count, cal_num);
  if (task_id >= output_tile_count) {
    // No tile is assigned to this task.
    return;
  }

  // The scratch slices depend only on task_id, so resolve them once.
  float *gemm_input = packed_input + task_id * deep * cal_num;
  float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
  const size_t packed_input_size = deep * cal_num * sizeof(float);

  for (int b = 0; b < conv_param->input_batch_; b++) {
    const int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
    const int out_batch_offset = b * out_channel * output_count;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
      const int start_index = thread_id * cal_num;
      // The last tile may hold fewer than cal_num output positions.
      const int remaining = output_count - start_index;
      const int real_cal_num = remaining < cal_num ? remaining : cal_num;

      // Clear both scratch slices so the unused rows of a partial tile read as zero.
      memset(gemm_input, 0, packed_input_size);
      memset(col_major_gemm_input, 0, packed_input_size);
      Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);

      const int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
      float *gemm_output = output_data + out_offset;
      // NOTE(review): on ARM32/SSE builds the tile is packed with RowMajor2Col4Major, while
      // the non-ARM64 AdderOpt path (Adder12x4) indexes its left operand in 12-row tiles.
      // This looks like a layout mismatch on those targets -- confirm against the packers.
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
      RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
#else
      RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
#endif
      AdderOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_num,
               out_channel, out_channel);
    }
  }
}
|
@ -0,0 +1,47 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_NNACL_FP32_ADDER_H_
|
||||||
|
#define MINDSPORE_LITE_NNACL_FP32_ADDER_H_
|
||||||
|
|
||||||
|
#ifdef ENABLE_NEON
|
||||||
|
#include <arm_neon.h>
|
||||||
|
#endif
|
||||||
|
#include "nnacl/pack.h"
|
||||||
|
#include "nnacl/op_base.h"
|
||||||
|
#include "nnacl/common_func.h"
|
||||||
|
#include "nnacl/conv_parameter.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef ENABLE_ARM64
|
||||||
|
// Adder kernel declared only when ENABLE_ARM64 is defined; AdderOpt forwards to it on
// that target, passing the activation type as an int.
// NOTE(review): the definition is not part of this header -- presumably an assembly/NEON
// routine; confirm where it lives and what operand layout it expects.
void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||||
|
int col, size_t stride);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Computes the adder product of `a` and `b` into `c` with optional bias and activation.
// The implementation is selected at compile time (ENABLE_ARM64 or the plain-C fallback).
// `deep` is the reduction length, `row`/`col` the output extent, and `stride` the
// element distance between consecutive output rows.
void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
|
||||||
|
size_t stride);
|
||||||
|
|
||||||
|
// Runs the adder convolution for the output tiles assigned to `task_id`.
// `packed_input` and `col_major_input` are caller-provided scratch buffers that are
// sliced per task; `packed_weight` and `bias_data` are forwarded to AdderOpt; results
// are written to `output_data`. Geometry, activation and thread count come from
// `conv_param`.
void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
|
||||||
|
float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_NNACL_FP32_ADDER_H_
|
@ -0,0 +1,133 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/runtime/kernel/arm/fp32/adder_fp32.h"
|
||||||
|
#include "src/kernel_registry.h"
|
||||||
|
#include "src/runtime/runtime_api.h"
|
||||||
|
#include "include/errorcode.h"
|
||||||
|
#include "schema/model_generated.h"
|
||||||
|
#include "nnacl/fp32/adder_fp32.h"
|
||||||
|
#include "nnacl/fp32/matmul_fp32.h"
|
||||||
|
|
||||||
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||||
|
using mindspore::lite::KernelRegistrar;
|
||||||
|
using mindspore::lite::RET_ERROR;
|
||||||
|
using mindspore::lite::RET_INFER_INVALID;
|
||||||
|
using mindspore::lite::RET_OK;
|
||||||
|
using mindspore::schema::PrimitiveType_Adder;
|
||||||
|
using mindspore::schema::Format::Format_NHWC;
|
||||||
|
|
||||||
|
namespace mindspore::kernel {
|
||||||
|
int AdderCPUKernel::InitWeightBias() {
|
||||||
|
auto filter_tensor = in_tensors_.at(kWeightIndex);
|
||||||
|
int kernel_h = filter_tensor->Height();
|
||||||
|
int kernel_w = filter_tensor->Width();
|
||||||
|
int in_channel = filter_tensor->Channel();
|
||||||
|
int out_channel = filter_tensor->Batch();
|
||||||
|
conv_param_->input_channel_ = in_channel;
|
||||||
|
conv_param_->output_channel_ = out_channel;
|
||||||
|
int kernel_plane = kernel_h * kernel_w;
|
||||||
|
const int oc_block = C4NUM;
|
||||||
|
int oc_block_num = UP_DIV(out_channel, C4NUM);
|
||||||
|
int pack_weight_size = oc_block_num * oc_block * in_channel * kernel_plane;
|
||||||
|
|
||||||
|
auto origin_weight = reinterpret_cast<float *>(filter_tensor->MutableData());
|
||||||
|
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
|
||||||
|
if (packed_weight_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc packed weight failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
|
||||||
|
RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
|
||||||
|
|
||||||
|
bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
|
||||||
|
if (bias_data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc bias failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));
|
||||||
|
|
||||||
|
if (in_tensors_.size() == kInputSize2) {
|
||||||
|
auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
|
||||||
|
memcpy(bias_data_, ori_bias, out_channel * sizeof(float));
|
||||||
|
} else {
|
||||||
|
MS_ASSERT(in_tensors_.size() == kInputSize1);
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdderCPUKernel::RunImpl(int task_id) {
|
||||||
|
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||||
|
auto ori_input_data = reinterpret_cast<float *>(input_tensor->data_c());
|
||||||
|
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
|
||||||
|
AdderFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), col_major_input_,
|
||||||
|
output_addr, task_id, conv_param_);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdderImpl(void *cdata, int task_id) {
|
||||||
|
auto adder = reinterpret_cast<AdderCPUKernel *>(cdata);
|
||||||
|
auto error_code = adder->RunImpl(task_id);
|
||||||
|
if (error_code != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Adder Run error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdderCPUKernel::Run() {
|
||||||
|
auto ret = InitTmpBuffer();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Init tmp buffer failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
int error_code = ParallelLaunch(this->context_->thread_pool_, AdderImpl, this, thread_count_);
|
||||||
|
if (error_code != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adder error error_code[" << error_code << "]";
|
||||||
|
FreeTmpBuffer();
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
FreeTmpBuffer();
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Factory for the fp32 CPU adder kernel.
 *
 * Constructs an AdderCPUKernel from the given tensors and parameter and initializes it.
 * A RET_INFER_INVALID result from Init() is accepted alongside RET_OK.
 *
 * @param inputs       input tensors of the node
 * @param outputs      output tensors of the node
 * @param op_parameter operator parameter; freed here if the kernel cannot be constructed
 * @param ctx          runtime context forwarded to the kernel
 * @param desc         kernel key; only checked by assertions
 * @param primitive    primitive forwarded to the kernel
 * @return the new kernel, or nullptr when construction or initialization fails
 */
kernel::LiteKernel *CpuAdderFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                              const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                              const InnerContext *ctx, const kernel::KernelKey &desc,
                                              const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(op_parameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_Adder);
  MS_ASSERT(desc.data_type == kNumberTypeFloat32);
  kernel::LiteKernel *kernel = new (std::nothrow) kernel::AdderCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    free(op_parameter);
    return nullptr;
  }

  auto ret = kernel->Init();
  if (ret != RET_OK && ret != RET_INFER_INVALID) {
    // Log before deleting: the kernel was handed op_parameter, so reading
    // op_parameter->name_ / type_ after `delete kernel` could touch freed memory.
    MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}
|
||||||
|
|
||||||
|
// Associates CpuAdderFp32KernelCreator with the (kCPU, kNumberTypeFloat32, PrimitiveType_Adder) key.
// NOTE(review): the registration mechanics live in the REG_KERNEL macro, defined outside this file.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Adder, CpuAdderFp32KernelCreator)
|
||||||
|
} // namespace mindspore::kernel
|
@ -0,0 +1,41 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "src/lite_kernel.h"
|
||||||
|
#include "nnacl/op_base.h"
|
||||||
|
#include "src/runtime/kernel/arm/fp32/convolution_fp32.h"
|
||||||
|
#include "nnacl/fp32/conv_fp32.h"
|
||||||
|
|
||||||
|
namespace mindspore::kernel {
|
||||||
|
// CPU kernel for the adder operator. It derives from ConvolutionCPUKernel and overrides
// the weight/bias packing step and the run entry points; everything else is inherited.
class AdderCPUKernel : public ConvolutionCPUKernel {
|
||||||
|
public:
|
||||||
|
// Forwards every argument unchanged to the ConvolutionCPUKernel constructor.
AdderCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||||
|
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||||
|
const mindspore::lite::PrimitiveC *primitive)
|
||||||
|
: ConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||||
|
~AdderCPUKernel() override = default;
|
||||||
|
|
||||||
|
// Packs the filter tensor and bias into the buffers used by the adder computation.
int InitWeightBias() override;
|
||||||
|
// Executes the operator across the kernel's tasks.
int Run() override;
|
||||||
|
// Executes the portion of the work that belongs to the task `task_id`.
int RunImpl(int task_id) override;
|
||||||
|
};
|
||||||
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_
|
Loading…
Reference in new issue