parent 4278813832
commit 3bae799513
@@ -0,0 +1,134 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp32/gru_fp32.h"
#include <string.h>
#include "nnacl/fp32/lstm_fp32.h"
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32/arithmetic_fp32.h"

// Broadcast the (already summed) per-gate bias across the batch for each of the three gates,
// so the MatMulAcc calls below can accumulate on top of it.
void InitGruGate(float *gate_buffer, const float *bias, const GruParameter *gru_parm) {
  int gate_offset = 0;
  for (int l = 0; l < 3; l++) {
    int batch_offset = gate_offset;
    int bias_offset = l * gru_parm->hidden_size_;
    for (int b = 0; b < gru_parm->batch_; b++) {
      memcpy(gate_buffer + batch_offset, bias + bias_offset, gru_parm->hidden_size_ * sizeof(float));
      batch_offset += gru_parm->hidden_size_;
    }
    gate_offset += gru_parm->batch_ * gru_parm->hidden_size_;
  }
}

void GruStepUnit(float *output, const float *input, const float *input_reset_weight, const float *input_update_weight,
                 const float *input_hidden_weight, const float *state_reset_weight, const float *state_update_weight,
                 const float *state_hidden_weight, const float *bias, float *hidden_state, float *gate_buffer,
                 const GruParameter *gru_parm) {
  InitGruGate(gate_buffer, bias, gru_parm);

  float *update_gate = gate_buffer;
  float *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_;
  float *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2;

  // input * weight
  MatMulAcc(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
  MatMulAcc(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
  MatMulAcc(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);

  // state * weight
  MatMulAcc(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_,
            gru_parm->hidden_size_);
  MatMulAcc(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
            gru_parm->hidden_size_);

  // update reset_gate
  Sigmoid(reset_gate, gru_parm->batch_ * gru_parm->hidden_size_, reset_gate);

  // update update_gate
  Sigmoid(update_gate, gru_parm->batch_ * gru_parm->hidden_size_, update_gate);

  // candidate state: hidden_buffer already holds Wh*x + bh, add Uh*(r . h_prev) and apply tanh
  ElementMul(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
  MatMulAcc(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
            gru_parm->hidden_size_);

  Tanh(hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_, hidden_buffer);

  // new hidden state: h = z . h_prev + (1 - z) . candidate
  ElementMul(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);

  ArithmeticParameter parameter;
  parameter.in_elements_num0_ = 1;
  parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_;
  const float one = 1.0f;
  ElementOptSub(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, &parameter);

  ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);

  memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float));
}

void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
         float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm) {
  // forward; weight_g layout per direction: [update, reset, hidden] blocks of [hidden_size, input_size]
  const float *input_update_weight = weight_g;
  const float *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_;
  const float *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2;

  const float *state_update_weight = weight_r;
  const float *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_;
  const float *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2;

  for (int t = 0; t < check_seq_len; t++) {
    const float *input_ptr = input + t * gru_parm->input_step_;
    float *output_ptr = output + t * gru_parm->output_step_;
    GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
                state_reset_weight, state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer,
                gru_parm);
  }
  // zero out extra fw outputs
  for (int t = check_seq_len; t < gru_parm->seq_len_; t++) {
    float *output_ptr = output + t * gru_parm->output_step_;
    for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
      output_ptr[i] = 0.0f;
    }
  }

  // backward
  if (gru_parm->bidirectional_) {
    input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3;
    input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4;
    input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5;

    state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3;
    state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4;
    state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5;

    float *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_;
    const float *backward_bias = bias + 3 * gru_parm->hidden_size_;
    float *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_;
    for (int t = check_seq_len - 1; t >= 0; t--) {
      const float *input_ptr = input + t * gru_parm->input_step_;
      float *output_ptr = backward_output + t * gru_parm->output_step_;
      GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
                  state_reset_weight, state_update_weight, state_hidden_weight, backward_bias, backward_hidden_state,
                  gate_buffer, gru_parm);
    }
    // zero out extra bw outputs
    for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) {
      float *output_ptr = backward_output + t * gru_parm->output_step_;
      for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
        output_ptr[i] = 0.0f;
      }
    }
  }
}
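Each GruStepUnit call above computes one GRU time step with linear_before_reset = 0 (the convention also noted in src/ops/gru.h further down). As a sketch of the math it implements, writing W_* for the input_*_weight blocks, U_* for the state_*_weight blocks, z for the update gate, r for the reset gate, \tilde{h} for the candidate state, and b_* for the per-gate bias (input and recurrent biases already summed by the caller, see GruCPUKernel::InitWeightBias below):

  z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)
  r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)
  \tilde{h}_t = \tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h)
  h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t

This matches the operation order in the code: the candidate buffer starts from W_h x_t + b_h, the reset gate is applied to h_{t-1} before the recurrent matmul, and the final ElementMul / ElementOptSub / ElementMulAcc sequence forms z \odot h_{t-1} + (1 - z) \odot \tilde{h}.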
@@ -0,0 +1,43 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
#define MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
#include "nnacl/op_base.h"

typedef struct GruParameter {
  // Primitive parameter
  OpParameter op_parameter_;
  // shape correlative
  int input_size_;
  int hidden_size_;  // output_size
  int seq_len_;
  int batch_;
  // other parameter
  int input_step_;
  int output_step_;
  bool bidirectional_;
} GruParameter;

#ifdef __cplusplus
extern "C" {
#endif
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
         float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
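To make the buffer contract of this header concrete, here is a minimal calling sketch. The sizes and the unidirectional configuration are illustrative assumptions only; in the real pipeline GruCPUKernel::ReSize()/Run() below fill GruParameter and size these buffers from the tensor shapes.

#include <vector>
#include "nnacl/fp32/gru_fp32.h"

int main() {
  // Hypothetical dimensions, for illustration only.
  GruParameter parm = {};
  parm.seq_len_ = 4;
  parm.batch_ = 1;
  parm.input_size_ = 8;
  parm.hidden_size_ = 16;  // also the per-direction output size
  parm.bidirectional_ = false;
  parm.input_step_ = parm.batch_ * parm.input_size_;
  parm.output_step_ = parm.batch_ * parm.hidden_size_;  // 2 * batch * hidden for bidirectional

  int dir = parm.bidirectional_ ? 2 : 1;
  std::vector<float> input(parm.seq_len_ * parm.input_step_, 0.0f);
  std::vector<float> output(parm.seq_len_ * parm.output_step_, 0.0f);
  // weight_g: [dir, 3 * hidden, input], weight_r: [dir, 3 * hidden, hidden]
  std::vector<float> weight_g(dir * 3 * parm.hidden_size_ * parm.input_size_, 0.0f);
  std::vector<float> weight_r(dir * 3 * parm.hidden_size_ * parm.hidden_size_, 0.0f);
  // bias: input-side and recurrent biases already summed, 3 * hidden per direction
  std::vector<float> bias(dir * 3 * parm.hidden_size_, 0.0f);
  std::vector<float> hidden_state(dir * parm.batch_ * parm.hidden_size_, 0.0f);
  // gate buffer holds update gate, reset gate and candidate state for one step
  std::vector<float> gate_buffer(3 * parm.batch_ * parm.hidden_size_, 0.0f);

  Gru(output.data(), input.data(), weight_g.data(), weight_r.data(), bias.data(),
      hidden_state.data(), gate_buffer.data(), parm.seq_len_, &parm);
  return 0;
}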
@@ -0,0 +1,121 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/ops/gru.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool Gru::GetBidirection() const { return this->primitive_->value.AsGru()->bidirection; }

void Gru::SetBidirection(bool bidirection) { this->primitive_->value.AsGru()->bidirection = bidirection; }

#else

bool Gru::GetBidirection() const { return this->primitive_->value_as_Gru()->bidirection(); }
int Gru::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_Gru();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_Gru returned nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateGru(*fbb, attr->bidirection());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Gru, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}

PrimitiveC *GruCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Gru>(primitive); }
Registry GruRegistry(schema::PrimitiveType_Gru, GruCreator);
#endif

const int kGruInputNum = 5;
const int kGruInputWithSeqLenNum = 6;
const int kGruOutputNum = 2;
int Gru::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(this->primitive_ != nullptr);
  if ((inputs_.size() != kGruInputNum && inputs_.size() != kGruInputWithSeqLenNum) ||
      outputs_.size() != kGruOutputNum) {
    MS_LOG(ERROR) << "OpGru inputs or outputs size error.";
    return RET_INPUT_TENSOR_ERROR;
  }
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto weight_gate = inputs_.at(1);
  MS_ASSERT(weight_gate != nullptr);
  auto weight_recurrence = inputs_.at(2);
  MS_ASSERT(weight_recurrence != nullptr);
  auto bias = inputs_.at(3);
  MS_ASSERT(bias != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  for (int i = 0; i < kGruOutputNum; i++) {
    outputs_.at(i)->set_data_type(input->data_type());
    outputs_.at(i)->set_format(input->format());
  }
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }

  auto in_shape = input->shape();                  // seq_len, batch, input_size
  auto w_gate_shape = weight_gate->shape();        // num_direction, hidden_size * 3, input_size
  auto w_recu_shape = weight_recurrence->shape();  // num_direction, hidden_size * 3, hidden_size
  auto bias_shape = bias->shape();                 // num_direction, hidden_size * 6
  if (in_shape.size() != 3 || w_gate_shape.size() != 3 || w_recu_shape.size() != 3) {
    MS_LOG(ERROR) << "OpGru input dims should be 3.";
    return RET_ERROR;
  }
  if (w_gate_shape[1] != w_recu_shape[1] || w_recu_shape[1] * 2 != bias_shape[1]) {
    MS_LOG(ERROR) << "OpGru w_gate, w_recu and bias hidden sizes do not match.";
    return RET_ERROR;
  }
  if (inputs_.size() == kGruInputWithSeqLenNum) {
    auto seq_len_shape = inputs_.at(5)->shape();
    if (seq_len_shape[0] > 1) {
      MS_LOG(WARNING) << "OpGru with batch_size > 1 currently only supports an identical sequence_len across the batch.";
      return RET_ERROR;
    }
    if (seq_len_shape.size() != 1 && seq_len_shape[0] != in_shape[1]) {
      MS_LOG(ERROR) << "OpGru sequence_len shape[0] does not match batch_size.";
      return RET_ERROR;
    }
  }

  int hidden_size = w_gate_shape[1] / 3;
  // set output
  std::vector<int> out_shape(in_shape);
  out_shape[2] = hidden_size;
  if (GetBidirection()) {
    out_shape.insert(out_shape.begin() + 1, 2);
  } else {
    out_shape.insert(out_shape.begin() + 1, 1);
  }
  output->set_shape(out_shape);
  // set hidden state
  std::vector<int> state_shape(in_shape);
  state_shape[0] = GetBidirection() ? 2 : 1;
  state_shape[2] = hidden_size;
  outputs_[1]->set_shape(state_shape);

  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore
@@ -0,0 +1,47 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_OPS_GRU_H_
#define MINDSPORE_LITE_SRC_OPS_GRU_H_
#include <vector>
#include <set>
#include <cmath>

#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
/*
 * GRU with linear_before_reset = 0: the reset gate is applied to the previous hidden state
 * before the recurrent matmul of the candidate state.
 */
class Gru : public PrimitiveC {
 public:
  Gru() = default;
  ~Gru() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Gru, PrimitiveC);
  explicit Gru(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetBidirection(bool bidirection);

#else
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  bool GetBidirection() const;
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_OPS_GRU_H_
@@ -0,0 +1,42 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstring>
#include "src/ops/gru.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32/gru_fp32.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateGruParameter(const mindspore::lite::PrimitiveC *primitive) {
  GruParameter *gru_param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter)));
  if (gru_param == nullptr) {
    MS_LOG(ERROR) << "malloc GruParameter failed.";
    return nullptr;
  }
  memset(gru_param, 0, sizeof(GruParameter));
  gru_param->op_parameter_.type_ = primitive->Type();
  auto param = reinterpret_cast<mindspore::lite::Gru *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  if (param == nullptr) {
    free(gru_param);
    MS_LOG(ERROR) << "cast to Gru primitive returned nullptr.";
    return nullptr;
  }
  gru_param->bidirectional_ = param->GetBidirection();
  return reinterpret_cast<OpParameter *>(gru_param);
}
Registry GruParameterRegistry(schema::PrimitiveType_Gru, PopulateGruParameter);
}  // namespace lite
}  // namespace mindspore
@@ -0,0 +1,165 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp32/gru_fp32.h"
#include <algorithm>
#include <cstring>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gru;

namespace mindspore::kernel {
void GruCPUKernel::FreeTmpBuffer() {
  if (gate_buffer_ != nullptr) {
    free(gate_buffer_);
    gate_buffer_ = nullptr;
  }
  if (bias_ptr_ != nullptr) {
    free(bias_ptr_);
    bias_ptr_ = nullptr;
  }
  weight_g_ptr_ = nullptr;
  weight_r_ptr_ = nullptr;
}

int GruCPUKernel::InitParam() {
  auto input = in_tensors_.front();
  MS_ASSERT(input != nullptr);
  std::vector<int> in_shape = input->shape();
  gru_parm_->seq_len_ = in_shape.at(0);
  gru_parm_->batch_ = in_shape.at(1);
  gru_parm_->input_size_ = in_shape.at(2);

  auto weight_g = in_tensors_.at(1);
  MS_ASSERT(weight_g != nullptr);
  std::vector<int> w_shape = weight_g->shape();
  gru_parm_->hidden_size_ = w_shape.at(1) / 3;

  gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_;
  gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_
                                                      : gru_parm_->batch_ * gru_parm_->hidden_size_;
  return RET_OK;
}

int GruCPUKernel::InitBuffer() {
  // One hidden_size slab per gate (update, reset, candidate) per batch.
  gate_buffer_ = reinterpret_cast<float *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float)));
  if (gate_buffer_ == nullptr) {
    MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error.";
    return RET_ERROR;
  }
  return RET_OK;
}

int GruCPUKernel::InitWeightBias() {
  auto weight_gate = in_tensors_.at(1);
  MS_ASSERT(weight_gate != nullptr);
  weight_g_ptr_ = reinterpret_cast<float *>(weight_gate->data_c());

  auto weight_recu = in_tensors_.at(2);
  MS_ASSERT(weight_recu != nullptr);
  weight_r_ptr_ = reinterpret_cast<float *>(weight_recu->data_c());

  int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
  bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
  if (bias_ptr_ == nullptr) {
    MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error.";
    return RET_ERROR;
  }

  // Fold the input-side gate bias and the recurrent bias into a single per-gate bias.
  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
  const int state_bias_offset = 3 * gru_parm_->hidden_size_;
  for (int i = 0; i < state_bias_offset; i++) {
    bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
  }
  if (gru_parm_->bidirectional_) {
    bias_data += 3 * gru_parm_->hidden_size_ * 2;
    auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_;
    for (int i = 0; i < state_bias_offset; i++) {
      backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
    }
  }
  return RET_OK;
}

int GruCPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int GruCPUKernel::ReSize() {
  FreeTmpBuffer();
  auto ret = InitParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GruCPUKernel InitParam error.";
    return RET_ERROR;
  }

  ret = InitWeightBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error.";
    FreeTmpBuffer();
    return RET_ERROR;
  }

  ret = InitBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GruCPUKernel InitBuffer error.";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  return RET_OK;
}

int GruCPUKernel::Run() {
  auto input = in_tensors_.at(kInputIndex);
  MS_ASSERT(input != nullptr);
  auto hidden_state = in_tensors_.at(4);
  MS_ASSERT(hidden_state != nullptr);
  auto output = out_tensors_.at(0);
  MS_ASSERT(output != nullptr);
  auto input_ptr = reinterpret_cast<float *>(input->data_c());
  MS_ASSERT(input_ptr);
  auto output_ptr = reinterpret_cast<float *>(output->MutableData());
  MS_ASSERT(output_ptr);
  auto output_hidden_state = out_tensors_[1];
  // Seed the output hidden-state tensor with the initial state; Gru() then updates it in place.
  memcpy(output_hidden_state->MutableData(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
  int check_seq_len = gru_parm_->seq_len_;
  if (in_tensors_.size() == 6) {
    // All batches must share one sequence length; clamp the compute length to it.
    auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
    if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) {
      MS_LOG(ERROR) << "varying seq_len across the batch is currently not supported";
      return RET_ERROR;
    }
    check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
  }

  MS_ASSERT(weight_g_ptr_);
  MS_ASSERT(weight_r_ptr_);
  MS_ASSERT(bias_ptr_);
  MS_ASSERT(gate_buffer_);
  Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
      reinterpret_cast<float *>(output_hidden_state->MutableData()), gate_buffer_, check_seq_len, gru_parm_);
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gru, LiteKernelCreator<GruCPUKernel>)
}  // namespace mindspore::kernel
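The one non-obvious step in InitWeightBias above is the bias preparation. Per the shape check in src/ops/gru.cc, the bias tensor (input 3) is laid out as [num_direction, 6 * hidden_size]; presumably the first half per direction holds the input-side gate biases and the second half the recurrent biases (the usual ONNX layout), and either way the kernel folds the two halves into the single 3 * hidden_size bias per direction that InitGruGate later broadcasts. As a worked equation, per direction:

  b^{fused}_i = bias\_data[i] + bias\_data[i + 3 \cdot hidden\_size], \quad 0 \le i < 3 \cdot hidden\_size

which is exactly the loop bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset].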
@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp32/gru_fp32.h"

namespace mindspore::kernel {
class GruCPUKernel : public LiteKernel {
 public:
  GruCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
               const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
               const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_);
  }

  ~GruCPUKernel() override { FreeTmpBuffer(); }

  int Init() override;
  int ReSize() override;
  int Run() override;

 private:
  void FreeTmpBuffer();
  int InitParam();
  int InitBuffer();
  int InitWeightBias();

  float *gate_buffer_ = nullptr;
  const float *weight_g_ptr_ = nullptr;
  const float *weight_r_ptr_ = nullptr;
  float *bias_ptr_ = nullptr;
  GruParameter *gru_parm_ = nullptr;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
@@ -0,0 +1 @@
decoder_step_201217.pb 5
@@ -0,0 +1,61 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/tf/tf_select_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
// The TF Select op is mapped onto the Switch primitive.
STATUS TFSelectParser::Parse(const tensorflow::NodeDef &tf_op,
                             const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                             std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF SelectParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }

  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::SwitchT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new SwitchT failed";
    return RET_NULL_PTR;
  }

  primitive->value.type = schema::PrimitiveType_Switch;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  *output_size = 1;
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tfSelectParser("Select", new TFSelectParser());
}  // namespace lite
}  // namespace mindspore