From 23757309f725bb03da38c17cb761ac1e8aab2f1d Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Tue, 13 Oct 2020 10:20:11 +0800 Subject: [PATCH] [MSLITE][Develop] optimize arm fp32 cpu op lstm: add neon --- mindspore/lite/nnacl/fp32/lstm.c | 36 ++++++++++++++++++--- mindspore/lite/src/ops/constant_of_shape.cc | 6 +++- mindspore/lite/src/ops/lstm.cc | 2 +- 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/mindspore/lite/nnacl/fp32/lstm.c b/mindspore/lite/nnacl/fp32/lstm.c index 8742b5c7b8..febcd70d1a 100644 --- a/mindspore/lite/nnacl/fp32/lstm.c +++ b/mindspore/lite/nnacl/fp32/lstm.c @@ -37,16 +37,44 @@ void MatMulAcc(float *output, const float *input, const float *weight, int rows, for (int r = 0; r < rows; r++) { for (int c = 0; c < cols; c++) { float res = 0; - for (int i = 0; i < inner_size; i++) { - res += input[r * inner_size + i] * weight[c * inner_size + i]; + const float *input_col = input + r * inner_size; + const float *weight_col = weight + c * inner_size; + int index = 0; +#ifdef ENABLE_ARM + float32x4_t out = vdupq_n_f32(0.0f); + for (; index < inner_size - 4; index += 4) { + float32x4_t in_0 = vld1q_f32(input_col + index); + float32x4_t in_1 = vld1q_f32(weight_col + index); + out = vmlaq_f32(out, in_1, in_0); + } +#ifdef ENABLE_ARM64 + res += vaddvq_f32(out); +#else + float32x2_t add2 = vadd_f32(vget_low_f32(out), vget_high_f32(out)); + float32x2_t add4 = vpadd_f32(add2, add2); + res += vget_lane_f32(add4, 0); +#endif +#endif + for (; index < inner_size; index++) { + res += input_col[index] * weight_col[index]; } output[r * cols + c] += res; } } } -void ElementMulAcc(float *input0, float *input1, float *output, int element_size) { - for (int index = 0; index < element_size; index++) { +void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size) { + int index = 0; +#ifdef ENABLE_ARM + for (; index < element_size - 4; index += 4) { + float32x4_t in_0 = vld1q_f32(input0 + index); + float32x4_t in_1 = vld1q_f32(input1 + index); + float32x4_t out = vld1q_f32(output + index); + out = vmlaq_f32(out, in_1, in_0); + vst1q_f32(output + index, out); + } +#endif + for (; index < element_size; index++) { output[index] += input0[index] * input1[index]; } } diff --git a/mindspore/lite/src/ops/constant_of_shape.cc b/mindspore/lite/src/ops/constant_of_shape.cc index b80852a9c2..b6e6018394 100644 --- a/mindspore/lite/src/ops/constant_of_shape.cc +++ b/mindspore/lite/src/ops/constant_of_shape.cc @@ -67,7 +67,11 @@ int ConstantOfShape::InferShape(std::vector inputs_, std::vector(in_tensor->MutableData()); + auto in_data = reinterpret_cast(in_tensor->data_c()); + if (in_data == nullptr) { + MS_LOG(ERROR) << "Input data is nullptr"; + return RET_INFER_INVALID; + } int size = in_tensor->ElementsNum(); std::vector out_shape(size); for (int i = 0; i < size; ++i) { diff --git a/mindspore/lite/src/ops/lstm.cc b/mindspore/lite/src/ops/lstm.cc index af9df6c1df..eec7b1b915 100644 --- a/mindspore/lite/src/ops/lstm.cc +++ b/mindspore/lite/src/ops/lstm.cc @@ -51,7 +51,7 @@ int Lstm::InferShape(std::vector inputs_, std::vector output } auto input = inputs_.front(); MS_ASSERT(input != nullptr); - auto weight_i = inputs_.front(); + auto weight_i = inputs_[1]; MS_ASSERT(input0 != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr);