diff --git a/mindspore/lite/nnacl/fp16/gru_fp16.c b/mindspore/lite/nnacl/fp16/gru_fp16.c
new file mode 100644
index 0000000000..028a4c3263
--- /dev/null
+++ b/mindspore/lite/nnacl/fp16/gru_fp16.c
@@ -0,0 +1,139 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/fp16/gru_fp16.h"
+#include <string.h>
+#include "nnacl/fp16/lstm_fp16.h"
+#include "nnacl/fp16/activation_fp16.h"
+#include "nnacl/fp16/arithmetic_fp16.h"
+
+void InitGruGateFp16(float16_t *gate_buffer, const float16_t *bias, const GruParameter *gru_parm) {
+  int gate_offset = 0;
+  for (int l = 0; l < 3; l++) {
+    int batch_offset = gate_offset;
+    int bias_offset = l * gru_parm->hidden_size_;
+    for (int b = 0; b < gru_parm->batch_; b++) {
+      memcpy(gate_buffer + batch_offset, bias + bias_offset, gru_parm->hidden_size_ * sizeof(float16_t));
+      batch_offset += gru_parm->hidden_size_;
+    }
+    gate_offset += gru_parm->batch_ * gru_parm->hidden_size_;
+  }
+}
+
+void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_reset_weight,
+                     const float16_t *input_update_weight, const float16_t *input_hidden_weight,
+                     const float16_t *state_reset_weight, const float16_t *state_update_weight,
+                     const float16_t *state_hidden_weight, const float16_t *bias, float16_t *hidden_state,
+                     float16_t *gate_buffer, const GruParameter *gru_parm) {
+  InitGruGateFp16(gate_buffer, bias, gru_parm);
+
+  float16_t *update_gate = gate_buffer;
+  float16_t *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_;
+  float16_t *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2;
+
+  // input * weight
+  MatMulAccFp16(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
+  MatMulAccFp16(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
+                gru_parm->input_size_);
+  MatMulAccFp16(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
+                gru_parm->input_size_);
+
+  // state * weight
+  MatMulAccFp16(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_,
+                gru_parm->hidden_size_);
+  MatMulAccFp16(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
+                gru_parm->hidden_size_);
+
+  // update reset_gate
+  SigmoidFp16(reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
+
+  // update update_gate
+  SigmoidFp16(update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_);
+
+  ElementMulFp16(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
+  MatMulAccFp16(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
+                gru_parm->hidden_size_);
+
+  TanhFp16(hidden_buffer, hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_);
+
+  ElementMulFp16(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
+
+  ArithmeticParameter parameter;
+  parameter.in_elements_num0_ = 1;
+  parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_;
+  float16_t one = 1.0f;
+  ElementOptSubFp16(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, &parameter);
+
+  ElementMulAccFp16(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
+
+  memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float16_t));
+}
+
+void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r,
+             const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, int check_seq_len,
+             const GruParameter *gru_parm) {
+  // forward
+  const float16_t *input_update_weight = weight_g;
+  const float16_t *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_;
+  const float16_t *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2;
+
+  const float16_t *state_update_weight = weight_r;
+  const float16_t *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_;
+  const float16_t *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2;
+
+  for (int t = 0; t < check_seq_len; t++) {
+    const float16_t *input_ptr = input + t * gru_parm->input_step_;
+    float16_t *output_ptr = output + t * gru_parm->output_step_;
+    GruStepUnitFp16(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
+                    state_reset_weight, state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer,
+                    gru_parm);
+  }
+  // zero out extra fw outputs
+  for (int t = check_seq_len; t < gru_parm->seq_len_; t++) {
+    float16_t *output_ptr = output + t * gru_parm->output_step_;
+    for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
+      output_ptr[i] = 0.0f;
+    }
+  }
+
+  // backward
+  if (gru_parm->bidirectional_) {
+    input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3;
+    input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4;
+    input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5;
+
+    state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3;
+    state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4;
+    state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5;
+
+    float16_t *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_;
+    const float16_t *backward_bias = bias + 3 * gru_parm->hidden_size_;
+    float16_t *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_;
+    for (int t = check_seq_len - 1; t >= 0; t--) {
+      const float16_t *input_ptr = input + t * gru_parm->input_step_;
+      float16_t *output_ptr = backward_output + t * gru_parm->output_step_;
+      GruStepUnitFp16(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
+                      state_reset_weight, state_update_weight, state_hidden_weight, backward_bias,
+                      backward_hidden_state, gate_buffer, gru_parm);
+    }
+    // zero out extra bw outputs
+    for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) {
+      float16_t *output_ptr = backward_output + t * gru_parm->output_step_;
+      for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
+        output_ptr[i] = 0.0f;
+      }
+    }
+  }
+}
diff --git a/mindspore/lite/nnacl/fp16/gru_fp16.h b/mindspore/lite/nnacl/fp16/gru_fp16.h
new file mode 100644
index 0000000000..4d23cc0e96
--- /dev/null
+++ b/mindspore/lite/nnacl/fp16/gru_fp16.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_NNACL_FP16_GRU_H_
+#define MINDSPORE_LITE_NNACL_FP16_GRU_H_
+#include "nnacl/gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r,
+             const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, int check_seq_len,
+             const GruParameter *gru_parm);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_LITE_NNACL_FP16_GRU_H_
diff --git a/mindspore/lite/nnacl/fp32/gru_fp32.h b/mindspore/lite/nnacl/fp32/gru_fp32.h
index e247783501..a9fc4d2555 100644
--- a/mindspore/lite/nnacl/fp32/gru_fp32.h
+++ b/mindspore/lite/nnacl/fp32/gru_fp32.h
@@ -15,21 +15,7 @@
  */
 #ifndef MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
 #define MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
-#include "nnacl/op_base.h"
-
-typedef struct GruParameter {
-  // Primitive parameter
-  OpParameter op_parameter_;
-  // shape correlative
-  int input_size_;
-  int hidden_size_;  // output_size
-  int seq_len_;
-  int batch_;
-  // other parameter
-  int input_step_;
-  int output_step_;
-  bool bidirectional_;
-} GruParameter;
+#include "nnacl/gru_parameter.h"
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore/lite/nnacl/gru_parameter.h b/mindspore/lite/nnacl/gru_parameter.h
new file mode 100644
index 0000000000..cbd85e1c3f
--- /dev/null
+++ b/mindspore/lite/nnacl/gru_parameter.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_
+#define MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_
+
+#include "nnacl/op_base.h"
+
+typedef struct GruParameter {
+  // Primitive parameter
+  OpParameter op_parameter_;
+  // shape correlative
+  int input_size_;
+  int hidden_size_;  // output_size
+  int seq_len_;
+  int batch_;
+  // other parameter
+  int input_step_;
+  int output_step_;
+  bool bidirectional_;
+} GruParameter;
+
+#endif  // MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc
new file mode 100644
index 0000000000..6004013897
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc
@@ -0,0 +1,189 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/arm/fp16/gru_fp16.h"
+#include <vector>
+#include "schema/model_generated.h"
+#include "src/kernel_registry.h"
+#include "include/errorcode.h"
+#include "nnacl/fp16/gru_fp16.h"
+
+using mindspore::kernel::KERNEL_ARCH::kCPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_Gru;
+
+namespace mindspore::kernel {
+void GruFp16CPUKernel::FreeTmpBuffer() {
+  if (gate_buffer_ != nullptr) {
+    free(gate_buffer_);
+    gate_buffer_ = nullptr;
+  }
+  if (bias_ptr_ != nullptr) {
+    free(bias_ptr_);
+    bias_ptr_ = nullptr;
+  }
+  if (weight_g_ptr_ != nullptr) {
+    free(weight_g_ptr_);
+    weight_g_ptr_ = nullptr;
+  }
+  if (weight_r_ptr_ != nullptr) {
+    free(weight_r_ptr_);
+    weight_r_ptr_ = nullptr;
+  }
+}
+
+int GruFp16CPUKernel::InitParam() {
+  auto input = in_tensors_.front();
+  MS_ASSERT(input != nullptr);
+  std::vector<int> in_shape = input->shape();
+  gru_parm_->seq_len_ = in_shape.at(0);
+  gru_parm_->batch_ = in_shape.at(1);
+  gru_parm_->input_size_ = in_shape.at(2);
+
+  auto weight_g = in_tensors_.at(1);
+  MS_ASSERT(weight_g != nullptr);
+  std::vector<int> w_shape = weight_g->shape();
+  gru_parm_->hidden_size_ = w_shape.at(1) / 3;
+
+  gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_;
+  gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_
+                                                       : gru_parm_->batch_ * gru_parm_->hidden_size_;
+  return RET_OK;
+}
+
+int GruFp16CPUKernel::InitBuffer() {
+  gate_buffer_ =
+    reinterpret_cast<float16_t *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float16_t)));
+  if (gate_buffer_ == nullptr) {
+    MS_LOG(ERROR) << "GruFp16CPUKernel malloc gate_buffer error.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int GruFp16CPUKernel::InitWeightBias() {
+  auto weight_gate = in_tensors_.at(1);
+  MS_ASSERT(weight_gate != nullptr);
+  weight_g_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_gate->ElementsNum() * sizeof(float16_t)));
+  if (weight_g_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_g_ptr_ error.";
+    return RET_ERROR;
+  }
+  auto weight_g_data = reinterpret_cast<float *>(weight_gate->data_c());
+  for (size_t i = 0; i < weight_gate->ElementsNum(); i++) {
+    weight_g_ptr_[i] = (float16_t)weight_g_data[i];
+  }
+
+  auto weight_recu = in_tensors_.at(2);
+  MS_ASSERT(weight_recu != nullptr);
+  weight_r_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_recu->ElementsNum() * sizeof(float16_t)));
+  if (weight_r_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_r_ptr_ error.";
+    return RET_ERROR;
+  }
+  auto weight_r_data = reinterpret_cast<float *>(weight_recu->data_c());
+  for (size_t i = 0; i < weight_recu->ElementsNum(); i++) {
+    weight_r_ptr_[i] = (float16_t)weight_r_data[i];
+  }
+
+  int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
+  bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_num * sizeof(float16_t)));
+  if (bias_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "GruFp16CPUKernel malloc bias_ptr_ error.";
+    return RET_ERROR;
+  }
+
+  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
+  const int state_bias_offset = 3 * gru_parm_->hidden_size_;
+  for (int i = 0; i < state_bias_offset; i++) {
+    bias_ptr_[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
+  }
+  if (gru_parm_->bidirectional_) {
+    bias_data += 3 * gru_parm_->hidden_size_ * 2;
+    auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_;
+    for (int i = 0; i < state_bias_offset; i++) {
+      backward_bias[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
+    }
+  }
+  return RET_OK;
+}
+
+int GruFp16CPUKernel::Init() {
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int GruFp16CPUKernel::ReSize() {
+  FreeTmpBuffer();
+  auto ret = InitParam();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "GruFp16CPUKernel InitParam error.";
+    return RET_ERROR;
+  }
+
+  ret = InitWeightBias();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "GruFp16CPUKernel InitWeightBias error.";
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+
+  ret = InitBuffer();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "GruFp16CPUKernel InitBuffer error.";
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int GruFp16CPUKernel::Run() {
+  auto input = in_tensors_.at(kInputIndex);
+  MS_ASSERT(input != nullptr);
+  auto hidden_state = in_tensors_.at(4);
+  MS_ASSERT(hidden_state != nullptr);
+  auto output = out_tensors_.at(0);
+  MS_ASSERT(output != nullptr);
+  auto input_ptr = reinterpret_cast<float16_t *>(input->data_c());
+  MS_ASSERT(input_ptr);
+  auto output_ptr = reinterpret_cast<float16_t *>(output->data_c());
+  MS_ASSERT(output_ptr);
+  auto output_hidden_state = out_tensors_[1];
+  memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float16_t));
+  int check_seq_len = gru_parm_->seq_len_;
+  if (in_tensors_.size() == 6) {
+    auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
+    if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) {
+      MS_LOG(ERROR) << "different batch seq_len is currently not supported";
+      return RET_ERROR;
+    }
+    check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
+  }
+
+  MS_ASSERT(weight_g_ptr_ != nullptr);
+  MS_ASSERT(weight_r_ptr_ != nullptr);
+  MS_ASSERT(bias_ptr_ != nullptr);
+  MS_ASSERT(gate_buffer_ != nullptr);
+  GruFp16(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
+          reinterpret_cast<float16_t *>(output_hidden_state->data_c()), gate_buffer_, check_seq_len, gru_parm_);
+  return RET_OK;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Gru, LiteKernelCreator<GruFp16CPUKernel>)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h
new file mode 100644
index 0000000000..ca47089379
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_GRU_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_GRU_H_
+#include <vector>
+#include "src/lite_kernel.h"
+#include "nnacl/gru_parameter.h"
+
+namespace mindspore::kernel {
+class GruFp16CPUKernel : public LiteKernel {
+ public:
+  GruFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                   const mindspore::lite::PrimitiveC *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+    gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_);
+  }
+
+  ~GruFp16CPUKernel() override { FreeTmpBuffer(); }
+
+  int Init() override;
+  int ReSize() override;
+  int Run() override;
+
+ private:
+  void FreeTmpBuffer();
+  int InitParam();
+  int InitBuffer();
+  int InitWeightBias();
+
+  float16_t *gate_buffer_ = nullptr;
+  float16_t *weight_g_ptr_ = nullptr;
+  float16_t *weight_r_ptr_ = nullptr;
+  float16_t *bias_ptr_ = nullptr;
+  GruParameter *gru_parm_ = nullptr;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_GRU_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc
index cd70788e54..b623275215 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc
@@ -18,6 +18,7 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
+#include "nnacl/fp32/gru_fp32.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -70,11 +71,21 @@ int GruCPUKernel::InitBuffer() {
 int GruCPUKernel::InitWeightBias() {
   auto weight_gate = in_tensors_.at(1);
   MS_ASSERT(weight_gate != nullptr);
-  weight_g_ptr_ = reinterpret_cast<float *>(weight_gate->data_c());
+  weight_g_ptr_ = reinterpret_cast<float *>(malloc(weight_gate->ElementsNum() * sizeof(float)));
+  if (weight_g_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "GruCPUKernel malloc weight_g_ptr_ error.";
+    return RET_ERROR;
+  }
+  memcpy(weight_g_ptr_, weight_gate->data_c(), weight_gate->ElementsNum() * sizeof(float));
 
   auto weight_recu = in_tensors_.at(2);
   MS_ASSERT(weight_recu != nullptr);
-  weight_r_ptr_ = reinterpret_cast<float *>(weight_recu->data_c());
+  weight_r_ptr_ = reinterpret_cast<float *>(malloc(weight_recu->ElementsNum() * sizeof(float)));
+  if (weight_r_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "GruCPUKernel malloc weight_r_ptr_ error.";
+    return RET_ERROR;
+  }
+  memcpy(weight_r_ptr_, weight_recu->data_c(), weight_recu->ElementsNum() * sizeof(float));
 
   int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
   bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h
index 720323d520..ee661d9d25 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h
@@ -17,7 +17,7 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
 #include <vector>
 #include "src/lite_kernel.h"
-#include "nnacl/fp32/gru_fp32.h"
+#include "nnacl/gru_parameter.h"
 
 namespace mindspore::kernel {
 class GruCPUKernel : public LiteKernel {
@@ -42,8 +42,8 @@ class GruCPUKernel : public LiteKernel {
   int InitWeightBias();
 
   float *gate_buffer_ = nullptr;
-  const float *weight_g_ptr_ = nullptr;
-  const float *weight_r_ptr_ = nullptr;
+  float *weight_g_ptr_ = nullptr;
+  float *weight_r_ptr_ = nullptr;
   float *bias_ptr_ = nullptr;
   GruParameter *gru_parm_ = nullptr;
 };
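
For reference, the recurrence that GruStepUnitFp16 computes corresponds to the standard GRU cell below. The symbol names are illustrative and not taken from the patch; the gate order (update z, reset r, candidate n) is inferred from the weight offsets in GruFp16 and the gate_buffer layout, and the input and recurrent biases are pre-summed by InitWeightBias before InitGruGateFp16 broadcasts them over the batch.

    z_t = sigmoid(x_t * W_iz + h_{t-1} * W_hz + b_z)        // update gate (SigmoidFp16)
    r_t = sigmoid(x_t * W_ir + h_{t-1} * W_hr + b_r)        // reset gate (SigmoidFp16)
    n_t = tanh(x_t * W_in + (r_t . h_{t-1}) * W_hn + b_n)   // candidate state; "." is element-wise product
    h_t = z_t . h_{t-1} + (1 - z_t) . n_t                   // ElementMulFp16, ElementOptSubFp16, ElementMulAccFp16

In the bidirectional case, the backward pass reuses the same step with the weight blocks at offsets 3-5 of weight_g and weight_r, the second half of the folded bias, and writes its results at output + batch_ * hidden_size_ within each time step, which is why output_step_ doubles when bidirectional_ is set.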