[MSLITE] arm32 prelu

Ref: pull/12213/head
Author: ling, 4 years ago
Parent: ab0010acfc
Commit: f12d3f3896

@@ -18,89 +18,70 @@
 #include <arm_neon.h>
 #endif

-void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int task_id) {
-  float *negetive_slope_value = prelu_param_->slope_;
-  int c4 = prelu_param_->channel_num_ / C4NUM;
+void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int plane) {
+#ifdef ENABLE_ARM
+  float32x4_t zero_value = vdupq_n_f32(0);
+#endif
+  int plane_tile = plane / TILE_NUM * TILE_NUM;
   int channel_num = prelu_param_->channel_num_;
-  for (int j = task_id; j < prelu_param_->tile_block_; j += prelu_param_->op_parameter_.thread_num_) {
-    float *input_ptr = input + j * TILE_NUM * channel_num;
-    float *output_ptr = input_ptr;
-#ifdef ENABLE_ARM64
-    for (int i = 0; i < c4; i++) {
-      int c_offset = i * C4NUM;
-      float32x4_t slope_value = vld1q_f32(negetive_slope_value + c_offset);
-      float32x4_t v1 = vld1q_f32(input_ptr + c_offset);
-      float32x4_t v2 = vld1q_f32(input_ptr + c_offset + channel_num);
-      float32x4_t v3 = vld1q_f32(input_ptr + c_offset + 2 * channel_num);
-      float32x4_t v4 = vld1q_f32(input_ptr + c_offset + 3 * channel_num);
-      float32x4_t v5 = vld1q_f32(input_ptr + c_offset + 4 * channel_num);
-      float32x4_t v6 = vld1q_f32(input_ptr + c_offset + 5 * channel_num);
-      float32x4_t v7 = vld1q_f32(input_ptr + c_offset + 6 * channel_num);
-      float32x4_t v8 = vld1q_f32(input_ptr + c_offset + 7 * channel_num);
-
-      float32x4_t t1 = vmulq_f32(v1, slope_value);
-      float32x4_t t2 = vmulq_f32(v2, slope_value);
-      float32x4_t t3 = vmulq_f32(v3, slope_value);
-      float32x4_t t4 = vmulq_f32(v4, slope_value);
-      float32x4_t t5 = vmulq_f32(v5, slope_value);
-      float32x4_t t6 = vmulq_f32(v6, slope_value);
-      float32x4_t t7 = vmulq_f32(v7, slope_value);
-      float32x4_t t8 = vmulq_f32(v8, slope_value);
-      uint32x4_t flag1 = vclezq_f32(v1);
-      uint32x4_t flag2 = vclezq_f32(v2);
-      uint32x4_t flag3 = vclezq_f32(v3);
-      uint32x4_t flag4 = vclezq_f32(v4);
-      uint32x4_t flag5 = vclezq_f32(v5);
-      uint32x4_t flag6 = vclezq_f32(v6);
-      uint32x4_t flag7 = vclezq_f32(v7);
-      uint32x4_t flag8 = vclezq_f32(v8);
-      float32x4_t r1 = vbslq_f32(flag1, t1, v1);
-      float32x4_t r2 = vbslq_f32(flag2, t2, v2);
-      float32x4_t r3 = vbslq_f32(flag3, t3, v3);
-      float32x4_t r4 = vbslq_f32(flag4, t4, v4);
-      float32x4_t r5 = vbslq_f32(flag5, t5, v5);
-      float32x4_t r6 = vbslq_f32(flag6, t6, v6);
-      float32x4_t r7 = vbslq_f32(flag7, t7, v7);
-      float32x4_t r8 = vbslq_f32(flag8, t8, v8);
-      vst1q_f32(output_ptr + c_offset, r1);
-      vst1q_f32(output_ptr + c_offset + channel_num, r2);
-      vst1q_f32(output_ptr + c_offset + 2 * channel_num, r3);
-      vst1q_f32(output_ptr + c_offset + 3 * channel_num, r4);
-      vst1q_f32(output_ptr + c_offset + 4 * channel_num, r5);
-      vst1q_f32(output_ptr + c_offset + 5 * channel_num, r6);
-      vst1q_f32(output_ptr + c_offset + 6 * channel_num, r7);
-      vst1q_f32(output_ptr + c_offset + 7 * channel_num, r8);
-    }  // c4 -1 loop
-#else
-    for (int i = 0; i < TILE_NUM; ++i) {
-      int tile_offset = i * channel_num;
-      for (int k = 0; k < c4; ++k) {
-        int c4_offset = tile_offset + k * C4NUM;
-        int slope_offset = k * C4NUM;
-        for (int l = 0; l < C4NUM; ++l) {
-          const float in_data = input_ptr[c4_offset + l];
-          output_ptr[c4_offset + l] =
-            (in_data < 0 ? in_data : 0) * negetive_slope_value[slope_offset + l] + (in_data > 0 ? in_data : 0);
-        }
-      }
-    }  // c4 - 1 loop
+  int plane_index = 0;
+  for (; plane_index < plane_tile; plane_index += TILE_NUM) {
+    float *in_plane_ptr = input + plane_index * channel_num;
+    float *out_plane_ptr = output + plane_index * channel_num;
+    int channel_index = 0;
+#ifdef ENABLE_ARM
+    float *negetive_slope_value = prelu_param_->slope_;
+    int div_channel = prelu_param_->channel_num_ / C4NUM * C4NUM;
+    for (; channel_index < div_channel; channel_index += C4NUM) {
+      float32x4_t slope_value = vld1q_f32(negetive_slope_value + channel_index);
+      float32x4_t v1 = vld1q_f32(in_plane_ptr + channel_index + 0 * channel_num);
+      float32x4_t v2 = vld1q_f32(in_plane_ptr + channel_index + 1 * channel_num);
+      float32x4_t v3 = vld1q_f32(in_plane_ptr + channel_index + 2 * channel_num);
+      float32x4_t v4 = vld1q_f32(in_plane_ptr + channel_index + 3 * channel_num);
+      float32x4_t v5 = vld1q_f32(in_plane_ptr + channel_index + 4 * channel_num);
+      float32x4_t v6 = vld1q_f32(in_plane_ptr + channel_index + 5 * channel_num);
+      float32x4_t v7 = vld1q_f32(in_plane_ptr + channel_index + 6 * channel_num);
+      float32x4_t v8 = vld1q_f32(in_plane_ptr + channel_index + 7 * channel_num);
+      float32x4_t r1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value));
+      float32x4_t r2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value));
+      float32x4_t r3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value));
+      float32x4_t r4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value));
+      float32x4_t r5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value));
+      float32x4_t r6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value));
+      float32x4_t r7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value));
+      float32x4_t r8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value));
+      vst1q_f32(out_plane_ptr + channel_index + 0 * channel_num, r1);
+      vst1q_f32(out_plane_ptr + channel_index + 1 * channel_num, r2);
+      vst1q_f32(out_plane_ptr + channel_index + 2 * channel_num, r3);
+      vst1q_f32(out_plane_ptr + channel_index + 3 * channel_num, r4);
+      vst1q_f32(out_plane_ptr + channel_index + 4 * channel_num, r5);
+      vst1q_f32(out_plane_ptr + channel_index + 5 * channel_num, r6);
+      vst1q_f32(out_plane_ptr + channel_index + 6 * channel_num, r7);
+      vst1q_f32(out_plane_ptr + channel_index + 7 * channel_num, r8);
+    }
 #endif
-    int c_s = c4 * C4NUM;
-    for (int m = 0; m < TILE_NUM; ++m) {
-      int offset = m * channel_num;
-      for (int k = c_s; k < channel_num; ++k) {
-        int c4_offset = offset + k;
-        const float in_data = input_ptr[c4_offset];
-        if (in_data >= 0) {
-          output_ptr[c4_offset] = in_data;
-        } else {
-          output_ptr[c4_offset] = in_data * negetive_slope_value[k];
-        }
-      }
-    }  // res loop
-  }
+    for (; channel_index < channel_num; channel_index++) {
+      float *in_c = in_plane_ptr + channel_index;
+      float *out_c = out_plane_ptr + channel_index;
+      for (int tile_i = 0; tile_i < TILE_NUM; tile_i++) {
+        float *in_tile = in_c + tile_i * channel_num;
+        float *out_tile = out_c + tile_i * channel_num;
+        const float in_data = in_tile[0];
+        out_tile[0] = (in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0);
+      }
+    }
+  }
+  for (; plane_index < plane; plane_index++) {
+    float *in_plane_ptr = input + plane_index * channel_num;
+    float *out_plane_ptr = output + plane_index * channel_num;
+    for (int channel_index = 0; channel_index < channel_num; channel_index++) {
+      const float in_data = in_plane_ptr[channel_index];
+      out_plane_ptr[channel_index] =
+        (in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0);
+    }
+  }
 }
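
The rewritten vector loop drops the vclezq_f32/vbslq_f32 select that the old code kept behind ENABLE_ARM64 and instead evaluates PReLU as min(x, 0) * slope + max(x, 0) with vminq_f32, vmaxq_f32, vmulq_f32 and vaddq_f32, which plain ARMv7 NEON also provides, so the same vector path now builds for arm32 under the broader ENABLE_ARM guard. A minimal scalar sketch of that identity (a hypothetical standalone file, not part of this commit):

/* prelu_identity_check.c -- sketch only: checks that the branch-free form
 * used by the new NEON path matches the branchy PReLU definition. */
#include <assert.h>
#include <math.h>
#include <stdio.h>

/* Branchy reference definition of PReLU. */
static float prelu_ref(float x, float slope) { return x < 0 ? x * slope : x; }

/* min(x, 0) * slope + max(x, 0), the form the NEON path vectorizes. */
static float prelu_minmax(float x, float slope) {
  return fminf(x, 0.0f) * slope + fmaxf(x, 0.0f);
}

int main(void) {
  const float slope = 0.25f;
  const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 3.0f};
  for (int i = 0; i < 5; ++i) {
    assert(prelu_ref(xs[i], slope) == prelu_minmax(xs[i], slope));
    printf("prelu(%.2f) = %.2f\n", xs[i], prelu_minmax(xs[i], slope));
  }
  return 0;
}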

@@ -39,19 +39,39 @@ int PReluRun(void *cdata, int task_id) {
 }
 }  // namespace

-int PReluCPUKernel::Init() { return RET_OK; }
+int PReluCPUKernel::Init() {
+  if (in_tensors_[1]->ElementsNum() == 1) {
+    prelu_param_->channelShared = true;
+  } else {
+    prelu_param_->channelShared = false;
+  }
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}

 int PReluCPUKernel::DoExcute(int task_id) {
   if (prelu_param_->channelShared) {
     PReluShareChannel(input_data_, output_data_, prelu_param_, task_id);
   } else {
-    PRelu(input_data_, output_data_, prelu_param_, task_id);
+    int res_plane = prelu_param_->input_num_ - task_id * prelu_param_->tile_block_;
+    int plane = MSMIN(prelu_param_->tile_block_, res_plane);
+    if (plane <= 0) {
+      return RET_OK;
+    }
+    float *in = input_data_ + task_id * prelu_param_->tile_block_ * prelu_param_->channel_num_;
+    float *out = output_data_ + task_id * prelu_param_->tile_block_ * prelu_param_->channel_num_;
+    PRelu(in, out, prelu_param_, plane);
   }
   return RET_OK;
 }

-int PReluCPUKernel::ProcessInput() {
-  // input tensor
+int PReluCPUKernel::ReSize() {
+  if (prelu_param_->channelShared) {
+    return RET_OK;
+  }
   auto input_tensor = in_tensors_.at(0);
   auto in_shape = input_tensor->shape();
   auto n_dim = in_shape.size();
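
The DoExcute hunk above now hands each task a contiguous block of tile_block_ planes (the last task takes whatever remains), instead of striding over tile blocks by thread number; ReSize, continued in the next hunk, rounds tile_block_ up so thread_num_ such blocks cover the whole plane. A small standalone sketch of that partitioning (hypothetical example; the names mirror the kernel but nothing here is part of the commit):

/* plane_partition_sketch.c -- sketch only: how the per-task plane slices
 * computed in ReSize/DoExcute cover the input. */
#include <stdio.h>

#define TILE_NUM 8
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  const int plane = 37;      /* total number of planes (product of all dims but channel) */
  const int thread_num = 4;  /* tasks launched by ParallelLaunch */
  /* Same rounding as ReSize: per-task block, aligned to TILE_NUM. */
  const int tile_block = UP_DIV(UP_DIV(plane, TILE_NUM), thread_num) * TILE_NUM;

  for (int task_id = 0; task_id < thread_num; ++task_id) {
    int res_plane = plane - task_id * tile_block;
    int task_plane = MSMIN(tile_block, res_plane);
    if (task_plane <= 0) {
      printf("task %d: idle\n", task_id);
      continue;
    }
    /* Each task would call PRelu(in + offset, out + offset, param, task_plane)
     * with offset = task_id * tile_block * channel_num. */
    printf("task %d: planes [%d, %d)\n", task_id, task_id * tile_block,
           task_id * tile_block + task_plane);
  }
  return 0;
}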
@@ -60,57 +80,36 @@ int PReluCPUKernel::ProcessInput() {
   for (size_t i = 0; i < n_dim - 1; ++i) {
     input_plane *= in_shape.at(i);
   }
-  int tile_block = UP_DIV(input_plane, TILE_NUM);
-  prelu_param_->input_num_ = input_tensor->ElementsNum();
-  prelu_param_->tile_block_ = tile_block;
+  prelu_param_->input_num_ = input_plane;
+  prelu_param_->tile_block_ = UP_DIV(UP_DIV(input_plane, TILE_NUM), op_parameter_->thread_num_) * TILE_NUM;
   prelu_param_->channel_num_ = channel_num;
-  input_data_ =
-    reinterpret_cast<float *>(context_->allocator->Malloc(tile_block * TILE_NUM * channel_num * sizeof(float)));
-  if (input_data_ == nullptr) {
-    MS_LOG(ERROR) << "malloc input_data_ failed.";
-    return RET_ERROR;
-  }
-  memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float));
   return RET_OK;
 }

 int PReluCPUKernel::ProcessShareChannelInput() {
-  // input tensor
   auto input_tensor = in_tensors_.at(0);
   prelu_param_->input_num_ = input_tensor->ElementsNum();
+  int tile = 32;
 #ifdef ENABLE_ARM64
-  prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, 64);
-  input_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(prelu_param_->tile_block_ * 64 * sizeof(float)));
-  if (input_data_ == nullptr) {
-    MS_LOG(ERROR) << "malloc input_data_ failed.";
-    return RET_ERROR;
-  }
-  memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float));
-#elif ENABLE_ARM32
-  prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, 32);
-  input_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(prelu_param_->tile_block_ * 32 * sizeof(float)));
-  if (input_data_ == nullptr) {
-    MS_LOG(ERROR) << "malloc input_data_ failed.";
-    return RET_ERROR;
-  }
-  memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float));
-#else
-  prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, 32);
-  input_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(prelu_param_->tile_block_ * 32 * sizeof(float)));
+  tile = 64;
+#endif
+  prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, tile);
+  input_data_ =
+    reinterpret_cast<float *>(context_->allocator->Malloc(prelu_param_->tile_block_ * tile * sizeof(float)));
   if (input_data_ == nullptr) {
     MS_LOG(ERROR) << "malloc input_data_ failed.";
     return RET_ERROR;
   }
   memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float));
-#endif
   return RET_OK;
 }

 int PReluCPUKernel::Run() {
   MS_ASSERT(in_tensors_.size() >= 2);
   auto input_tensor = in_tensors_[0];
-  ori_input_ = reinterpret_cast<float *>(input_tensor->MutableData());
-  output_data_ = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
+  ori_input_ = reinterpret_cast<float *>(input_tensor->data_c());
+  output_data_ = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
   MS_ASSERT(ori_input_);
   MS_ASSERT(output_data_);
   if (prelu_param_->channelShared) {
@@ -120,16 +119,12 @@ int PReluCPUKernel::Run() {
       return ret;
     }
   } else {
-    auto ret = ProcessInput();
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "Process failed.";
-      return ret;
-    }
+    input_data_ = ori_input_;
   }

   // negative slope tensor
   auto negative_slope_tensor = in_tensors_.at(1);
-  prelu_param_->slope_ = reinterpret_cast<float *>(negative_slope_tensor->MutableData());
+  prelu_param_->slope_ = reinterpret_cast<float *>(negative_slope_tensor->data_c());

   auto ret = ParallelLaunch(this->context_->thread_pool_, PReluRun, this, prelu_param_->op_parameter_.thread_num_);
   if (ret != RET_OK) {
@@ -138,8 +133,10 @@ int PReluCPUKernel::Run() {
     return RET_ERROR;
   }

-  memcpy(output_data_, input_data_, prelu_param_->input_num_ * sizeof(float));
-  context_->allocator->Free(input_data_);
+  if (prelu_param_->channelShared) {
+    memcpy(output_data_, input_data_, prelu_param_->input_num_ * sizeof(float));
+    context_->allocator->Free(input_data_);
+  }

   return RET_OK;
 }
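
With the copy removed from the per-channel path (input_data_ now simply aliases ori_input_), the temporary buffer only survives in the shared-slope path: ProcessShareChannelInput copies the input into a buffer rounded up to a multiple of the tile width (64 floats under ENABLE_ARM64, 32 otherwise), and Run copies the valid part back and frees it only when channelShared is set. A rough sketch of that padded-copy pattern, with plain malloc/free standing in for context_->allocator (an assumption for the sketch, not the kernel's allocator API):

/* shared_channel_padding.c -- sketch only: mimics the padded buffer that
 * ProcessShareChannelInput allocates for the shared-slope path. */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main(void) {
  enum { kInputNum = 100 };  /* stands in for input_tensor->ElementsNum() */
#ifdef ENABLE_ARM64
  const int tile = 64;
#else
  const int tile = 32;
#endif
  const int tile_block = UP_DIV(kInputNum, tile);

  /* Round the buffer up to tile_block * tile floats, matching
   * ProcessShareChannelInput, so the shared-channel kernel sees whole tiles. */
  float *input_data = (float *)malloc((size_t)tile_block * tile * sizeof(float));
  if (input_data == NULL) {
    fprintf(stderr, "malloc input_data failed\n");
    return 1;
  }
  float ori_input[kInputNum] = {0.0f};
  memcpy(input_data, ori_input, kInputNum * sizeof(float));

  /* ... PReluShareChannel would run over the padded buffer here ... */

  /* As in Run(): copy only the valid elements back, then free the buffer. */
  float output_data[kInputNum];
  memcpy(output_data, input_data, kInputNum * sizeof(float));
  free(input_data);
  printf("tiles: %d, padded floats: %d, valid floats: %d\n", tile_block,
         tile_block * tile, kInputNum);
  return 0;
}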

@@ -33,11 +33,10 @@ class PReluCPUKernel : public LiteKernel {
   ~PReluCPUKernel() = default;

   int Init() override;
-  int ReSize() override { return 0; }
+  int ReSize() override;
   int Run() override;
   int DoExcute(int task_id);
   int ProcessShareChannelInput();
-  int ProcessInput();

  private:
   PReluParameter *prelu_param_;
