lstm add smooth & add gru op & add 3 lstm fusion passes

pull/10695/head
wangzhe 5 years ago
parent 4278813832
commit 3bae799513

@@ -0,0 +1,134 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/gru_fp32.h"
#include <string.h>
#include "nnacl/fp32/lstm_fp32.h"
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32/arithmetic_fp32.h"
void InitGruGate(float *gate_buffer, const float *bias, const GruParameter *gru_parm) {
int gate_offset = 0;
for (int l = 0; l < 3; l++) {
int batch_offset = gate_offset;
int bias_offset = l * gru_parm->hidden_size_;
for (int b = 0; b < gru_parm->batch_; b++) {
memcpy(gate_buffer + batch_offset, bias + bias_offset, gru_parm->hidden_size_ * sizeof(float));
batch_offset += gru_parm->hidden_size_;
}
gate_offset += gru_parm->batch_ * gru_parm->hidden_size_;
}
}
void GruStepUnit(float *output, const float *input, const float *input_reset_weight, const float *input_update_weight,
const float *input_hidden_weight, const float *state_reset_weight, const float *state_update_weight,
const float *state_hidden_weight, const float *bias, float *hidden_state, float *gate_buffer,
const GruParameter *gru_parm) {
InitGruGate(gate_buffer, bias, gru_parm);
float *update_gate = gate_buffer;
float *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_;
float *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2;
// input * weight
MatMulAcc(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
MatMulAcc(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
MatMulAcc(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
// state * weight
MatMulAcc(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_,
gru_parm->hidden_size_);
MatMulAcc(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
gru_parm->hidden_size_);
// update reset_gate
Sigmoid(reset_gate, gru_parm->batch_ * gru_parm->hidden_size_, reset_gate);
// update update_gate
Sigmoid(update_gate, gru_parm->batch_ * gru_parm->hidden_size_, update_gate);
ElementMul(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
MatMulAcc(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
gru_parm->hidden_size_);
Tanh(hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_, hidden_buffer);
ElementMul(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
ArithmeticParameter parameter;
parameter.in_elements_num0_ = 1;
parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_;
const float one = 1.0f;
ElementOptSub(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, &parameter);
ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float));
}
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm) {
// forward
const float *input_update_weight = weight_g;
const float *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_;
const float *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2;
const float *state_update_weight = weight_r;
const float *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_;
const float *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2;
for (int t = 0; t < check_seq_len; t++) {
const float *input_ptr = input + t * gru_parm->input_step_;
float *output_ptr = output + t * gru_parm->output_step_;
GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, state_reset_weight,
state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer, gru_parm);
}
// zero out extra fw outputs
for (int t = check_seq_len; t < gru_parm->seq_len_; t++) {
float *output_ptr = output + t * gru_parm->output_step_;
for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
output_ptr[i] = 0.0f;
}
}
// backward
if (gru_parm->bidirectional_) {
input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3;
input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4;
input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5;
state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3;
state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4;
state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5;
float *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_;
const float *backward_bias = bias + 3 * gru_parm->hidden_size_;
float *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_;
for (int t = check_seq_len - 1; t >= 0; t--) {
const float *input_ptr = input + t * gru_parm->input_step_;
float *output_ptr = backward_output + t * gru_parm->output_step_;
GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
state_reset_weight, state_update_weight, state_hidden_weight, backward_bias, backward_hidden_state,
gate_buffer, gru_parm);
}
// zero out extra bw outputs
for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) {
float *output_ptr = backward_output + t * gru_parm->output_step_;
for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
output_ptr[i] = 0.0f;
}
}
}
}
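For reference, here is my summary of what GruStepUnit above computes, in conventional GRU notation (z the update gate, r the reset gate, \circ element-wise multiplication); it matches the linear_before_reset = 0 variant noted in src/ops/gru.h:

  z_t = \sigma(W_z x_t + R_z h_{t-1} + b_z)
  r_t = \sigma(W_r x_t + R_r h_{t-1} + b_r)
  \tilde{h}_t = \tanh(W_h x_t + R_h (r_t \circ h_{t-1}) + b_h)
  h_t = z_t \circ h_{t-1} + (1 - z_t) \circ \tilde{h}_t

Each hidden_size slice of the bias argument is replicated across the batch by InitGruGate before the matrix products are accumulated, so the bias is added exactly once per gate.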

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
#define MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
#include "nnacl/op_base.h"
typedef struct GruParameter {
// Primitive parameter
OpParameter op_parameter_;
// shape correlative
int input_size_;
int hidden_size_; // output_size
int seq_len_;
int batch_;
// other parameter
int input_step_;
int output_step_;
bool bidirectional_;
} GruParameter;
#ifdef __cplusplus
extern "C" {
#endif
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
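Below is a minimal usage sketch of this interface; it is my illustration rather than part of the patch, the sizes and the name RunGruExample are hypothetical, and the parameter/buffer setup simply mirrors GruCPUKernel::InitParam/InitBuffer later in this commit. The bias is assumed to be the pre-folded 3 * hidden_size vector the kernel prepares, and hidden_state holds the initial state (batch_ * hidden_size_ floats) and is updated in place.

#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include "nnacl/fp32/gru_fp32.h"

/* Runs a unidirectional GRU over a [seq_len, batch, input_size] input.
 * weight_g packs the update/reset/hidden input weights back to back and
 * weight_r the matching recurrent weights (see the pointer offsets in Gru()).
 * output must hold seq_len * batch * hidden_size floats. */
static void RunGruExample(const float *input, const float *weight_g, const float *weight_r,
                          const float *fused_bias, float *output, float *hidden_state) {
  GruParameter parm;
  memset(&parm, 0, sizeof(parm));
  parm.seq_len_ = 16;
  parm.batch_ = 1;
  parm.input_size_ = 32;
  parm.hidden_size_ = 64;
  parm.bidirectional_ = false;
  parm.input_step_ = parm.batch_ * parm.input_size_;
  parm.output_step_ = parm.batch_ * parm.hidden_size_;  /* doubled for bidirectional */
  /* one time step's update gate, reset gate and candidate state */
  float *gate_buffer = (float *)malloc(3 * parm.batch_ * parm.hidden_size_ * sizeof(float));
  if (gate_buffer == NULL) {
    return;
  }
  Gru(output, input, weight_g, weight_r, fused_bias, hidden_state, gate_buffer, parm.seq_len_, &parm);
  free(gate_buffer);
}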

@@ -16,6 +16,7 @@
#include "nnacl/fp32/lstm_fp32.h"
#include <string.h>
#include <float.h>
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32/arithmetic_fp32.h"
@@ -79,21 +80,63 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int
}
}
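// ElementOptMulAcc: output[i] += input0[i] * input1, where input1 is a scalar.
// Used below to accumulate the (1 - smooth)-weighted new state onto the
// smooth-weighted old state kept in state_buffer.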
int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 4; index += C4NUM) {
float32x4_t vin0 = vld1q_f32(input0 + index);
float32x4_t vout = vld1q_f32(output + index);
vout = vmlaq_n_f32(vout, vin0, input1);
vst1q_f32(output + index, vout);
}
#endif
for (; index < element_size; index++) {
output[index] += input0[index] * input1;
}
return NNACL_OK;
}
void UpdataState(float *cell_state, const float *forget_gate, const float *input_gate, const float *cell_gate,
int batch, int hidden_size) {
float *state_buffer, int batch, int hidden_size, const float smooth) {
if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // smooth * old_cell_state
memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float));
ArithmeticParameter parameter;
parameter.in_elements_num0_ = batch * hidden_size;
parameter.in_elements_num1_ = 1;
ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
}
ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size);
ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size);
if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // (1 - smooth) * new_cell_state
ElementOptMulAcc(cell_state, 1 - smooth, state_buffer, batch * hidden_size);
}
}
void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, int batch, int hidden_size) {
void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, float *state_buffer_in,
int batch, int hidden_size, const float smooth) {
float *state_buffer = state_buffer_in + batch * hidden_size;
if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float));
ArithmeticParameter parameter;
parameter.in_elements_num0_ = batch * hidden_size;
parameter.in_elements_num1_ = 1;
ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
}
Tanh(cell_state, batch * hidden_size, hidden_state);
ElementMul(hidden_state, output_gate, hidden_state, batch * hidden_size);
if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
ElementOptMulAcc(hidden_state, 1 - smooth, state_buffer, batch * hidden_size);
}
}
void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight,
const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight,
const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight,
const float *bias, float *hidden_state, float *cell_state, float *gate_buffer,
const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
const LstmParameter *lstm_parm) {
InitGate(gate_buffer, bias, lstm_parm);
@@ -129,17 +172,26 @@ void LstmStepUnit(float *output, const float *input, const float *input_input_we
// update cell_gate
Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate);
// update cell state
UpdataState(cell_state, forget_gate, input_gate, cell_gate, lstm_parm->batch_, lstm_parm->hidden_size_);
UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->smooth_);
// update output_gate
Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate);
// update output
UpdataOutput(cell_state, output_gate, hidden_state, lstm_parm->batch_, lstm_parm->hidden_size_);
UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->smooth_);
memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) {
memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_,
lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
}
}
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
float *hidden_state, float *cell_state, float *gate_buffer, const LstmParameter *lstm_parm) {
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
const LstmParameter *lstm_parm) {
// forward
const float *input_input_weight = weight_i;
const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;
@@ -156,7 +208,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float
float *output_ptr = output + t * lstm_parm->output_step_;
LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight,
state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state,
cell_state, gate_buffer, lstm_parm);
cell_state, gate_buffer, state_buffer, lstm_parm);
}
// backward
@@ -180,7 +232,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float
float *output_ptr = backward_output + t * lstm_parm->output_step_;
LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight,
backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, lstm_parm);
backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm);
}
}
}

@@ -31,13 +31,24 @@ typedef struct LstmParameter {
int input_step_;
int output_step_;
bool bidirectional_;
// smooth factor for hidden/cell state calculation:
// output_hidden = old_hidden * smooth + new_hidden * (1 - smooth)
// output_cell = old_cell * smooth + new_cell * (1 - smooth)
float smooth_;
} LstmParameter;
#ifdef __cplusplus
extern "C" {
#endif
void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size);
void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size);
int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size);
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
float *hidden_state, float *cell_state, float *gate_buffer, const LstmParameter *lstm_parm);
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
const LstmParameter *lstm_parm);
#ifdef __cplusplus
}
#endif
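Reading UpdataState/UpdataOutput together with the copy-back at the end of LstmStepUnit, a non-zero smooth_ (written s below, with f_t, i_t, g_t, o_t the forget, input, cell and output gates) turns the per-step update into:

  \tilde{c}_t = f_t \circ c_{t-1} + i_t \circ g_t
  c_t = s \cdot c_{t-1} + (1 - s) \cdot \tilde{c}_t
  h_t = s \cdot h_{t-1} + (1 - s) \cdot (o_t \circ \tanh(\tilde{c}_t))

The output branch is computed from the un-smoothed cell state \tilde{c}_t, because the blended values sit in state_buffer until they are copied back after UpdataOutput. That is also why state_buffer needs 2 * batch_ * hidden_size_ floats (blended cell state in the first half, blended hidden state in the second), matching the allocation added to LstmCPUKernel::InitBuffer in this commit.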

@@ -262,6 +262,7 @@ union PrimitiveType {
Merge,
Mod,
GeLU,
Gru,
}
enum QuantType: int {

@@ -1005,6 +1005,11 @@ table OneHot {
table Lstm{
bidirection: bool = false;
smooth: float = 0.0;
}
table Gru{
bidirection: bool = false;
}
table PriorBox {

@@ -0,0 +1,121 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/gru.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool Gru::GetBidirection() const { return this->primitive_->value.AsGru()->bidirection; }
void Gru::SetBidirection(bool bidirection) { this->primitive_->value.AsGru()->bidirection = bidirection; }
#else
bool Gru::GetBidirection() const { return this->primitive_->value_as_Gru()->bidirection(); }
int Gru::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Gru();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Gru return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateGru(*fbb, attr->bidirection());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Gru, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *GruCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Gru>(primitive); }
Registry GruRegistry(schema::PrimitiveType_Gru, GruCreator);
#endif
const int kGruInputNum = 5;
const int kGruInputWithSeqLenNum = 6;
const int kGruOutputNum = 2;
int Gru::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
if ((inputs_.size() != kGruInputNum && inputs_.size() != kGruInputWithSeqLenNum) ||
outputs_.size() != kGruOutputNum) {
MS_LOG(ERROR) << "OpGru inputs or outputs size error.";
return RET_INPUT_TENSOR_ERROR;
}
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto weight_gate = inputs_.at(1);
MS_ASSERT(weight_gate != nullptr);
auto weight_recurrence = inputs_.at(2);
MS_ASSERT(weight_recurrence != nullptr);
auto bias = inputs_.at(3);
MS_ASSERT(bias != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
for (int i = 0; i < kGruOutputNum; i++) {
outputs_.at(i)->set_data_type(input->data_type());
outputs_.at(i)->set_format(input->format());
}
if (!infer_flag()) {
return RET_INFER_INVALID;
}
auto in_shape = input->shape(); // seq_len, batch, input_size
auto w_gate_shape = weight_gate->shape(); // num_direction, hidden_size * 3, input_size
auto w_recu_shape = weight_recurrence->shape(); // num_direction, hidden_size * 3, hidden_size
auto bias_shape = bias->shape(); // num_direction, hidden_size * 6
if (in_shape.size() != 3 || w_gate_shape.size() != 3 || w_recu_shape.size() != 3) {
MS_LOG(ERROR) << "OpGru input dims should be 3.";
return RET_ERROR;
}
if (w_gate_shape[1] != w_recu_shape[1] || w_recu_shape[1] * 2 != bias_shape[1]) {
MS_LOG(ERROR) << "OpGru w_gate, w_recu and bias hidden size not match.";
return RET_ERROR;
}
if (inputs_.size() == kGruInputWithSeqLenNum) {
auto seq_len_shape = inputs_.at(5)->shape();
if (seq_len_shape[0] > 1) {
MS_LOG(WARNING) << "OpGru with batch_size > 1 only support all same sequence_len now.";
return RET_ERROR;
}
if (seq_len_shape.size() != 1 && seq_len_shape[0] != in_shape[1]) {
MS_LOG(ERROR) << "OpGru sequence_len shape[0] and batch_size not match.";
return RET_ERROR;
}
}
int hidden_size = w_gate_shape[1] / 3;
// set output
std::vector<int> out_shape(in_shape);
out_shape[2] = hidden_size;
if (GetBidirection()) {
out_shape.insert(out_shape.begin() + 1, 2);
} else {
out_shape.insert(out_shape.begin() + 1, 1);
}
output->set_shape(out_shape);
// set hidden state
std::vector<int> state_shape(in_shape);
state_shape[0] = GetBidirection() ? 2 : 1;
state_shape[2] = hidden_size;
outputs_[1]->set_shape(state_shape);
return RET_OK;
}
} // namespace lite
} // namespace mindspore
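A worked example of the shape inference above, using hypothetical sizes and a unidirectional GRU (num_direction = 1):

  // input:              {seq_len = 16, batch = 1, input_size = 32}
  // weight_gate:        {1, 3 * 64, 32}
  // weight_recurrence:  {1, 3 * 64, 64}
  // bias:               {1, 6 * 64}
  // -> output:          {16, 1, 1, 64}, i.e. {seq_len, num_direction, batch, hidden_size}
  // -> hidden state:    {1, 1, 64}, i.e. {num_direction, batch, hidden_size}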

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_GRU_H_
#define MINDSPORE_LITE_SRC_OPS_GRU_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
/*
* gru with linear_before_reset = 0
*/
class Gru : public PrimitiveC {
public:
Gru() = default;
~Gru() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Gru, PrimitiveC);
explicit Gru(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBidirection(bool bidirection);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
bool GetBidirection() const;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_GRU_H_

@@ -25,11 +25,16 @@ namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool Lstm::GetBidirection() const { return this->primitive_->value.AsLstm()->bidirection; }
float Lstm::GetSmooth() const { return this->primitive_->value.AsLstm()->smooth; }
void Lstm::SetBidirection(bool bidirection) { this->primitive_->value.AsLstm()->bidirection = bidirection; }
void Lstm::SetSmooth(float smooth) { this->primitive_->value.AsLstm()->smooth = smooth; }
#else
bool Lstm::GetBidirection() const { return this->primitive_->value_as_Lstm()->bidirection(); }
float Lstm::GetSmooth() const { return this->primitive_->value_as_Lstm()->smooth(); }
int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
@@ -38,7 +43,7 @@ int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
MS_LOG(ERROR) << "value_as_Lstm return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateLstm(*fbb, attr->bidirection());
auto val_offset = schema::CreateLstm(*fbb, attr->bidirection(), attr->smooth());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Lstm, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;

@@ -33,12 +33,14 @@ class Lstm : public PrimitiveC {
MS_DECLARE_PARENT(Lstm, PrimitiveC);
explicit Lstm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBidirection(bool bidirection);
void SetSmooth(float smooth);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
bool GetBidirection() const;
float GetSmooth() const;
};
} // namespace lite
} // namespace mindspore

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/gru.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32/gru_fp32.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateGruParameter(const mindspore::lite::PrimitiveC *primitive) {
GruParameter *gru_param = reinterpret_cast<GruParameter *>(malloc(sizeof(GruParameter)));
if (gru_param == nullptr) {
MS_LOG(ERROR) << "malloc GruParameter failed.";
return nullptr;
}
memset(gru_param, 0, sizeof(GruParameter));
gru_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::Gru *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
if (param == nullptr) {
free(gru_param);
MS_LOG(ERROR) << "get Gru param nullptr.";
return nullptr;
}
gru_param->bidirectional_ = param->GetBidirection();
return reinterpret_cast<OpParameter *>(gru_param);
}
Registry GruParameterRegistry(schema::PrimitiveType_Gru, PopulateGruParameter);
} // namespace lite
} // namespace mindspore

@@ -36,6 +36,7 @@ OpParameter *PopulateLstmParameter(const mindspore::lite::PrimitiveC *primitive)
return nullptr;
}
lstm_param->bidirectional_ = param->GetBidirection();
lstm_param->smooth_ = param->GetSmooth();
return reinterpret_cast<OpParameter *>(lstm_param);
}
Registry LstmParameterRegistry(schema::PrimitiveType_Lstm, PopulateLstmParameter);

@@ -161,6 +161,7 @@
#include "src/ops/switch.h"
#include "src/ops/partial.h"
#include "src/ops/gelu.h"
#include "src/ops/gru.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@@ -995,6 +996,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) AssertOP(primitive);
case schema::PrimitiveType_GeLU:
return new (std::nothrow) GeLU(primitive);
case schema::PrimitiveType_Gru:
return new (std::nothrow) Gru(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);

@@ -0,0 +1,165 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/gru_fp32.h"
#include <algorithm>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gru;
namespace mindspore::kernel {
void GruCPUKernel::FreeTmpBuffer() {
if (gate_buffer_ != nullptr) {
free(gate_buffer_);
gate_buffer_ = nullptr;
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
}
weight_g_ptr_ = nullptr;
weight_r_ptr_ = nullptr;
}
int GruCPUKernel::InitParam() {
auto input = in_tensors_.front();
MS_ASSERT(input != nullptr);
std::vector<int> in_shape = input->shape();
gru_parm_->seq_len_ = in_shape.at(0);
gru_parm_->batch_ = in_shape.at(1);
gru_parm_->input_size_ = in_shape.at(2);
auto weight_g = in_tensors_.at(1);
MS_ASSERT(weight_g != nullptr);
std::vector<int> w_shape = weight_g->shape();
gru_parm_->hidden_size_ = w_shape.at(1) / 3;
gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_;
gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_
: gru_parm_->batch_ * gru_parm_->hidden_size_;
return RET_OK;
}
int GruCPUKernel::InitBuffer() {
gate_buffer_ = reinterpret_cast<float *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float)));
if (gate_buffer_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error.";
return RET_ERROR;
}
return RET_OK;
}
int GruCPUKernel::InitWeightBias() {
auto weight_gate = in_tensors_.at(1);
MS_ASSERT(weight_gate != nullptr);
weight_g_ptr_ = reinterpret_cast<float *>(weight_gate->data_c());
auto weight_recu = in_tensors_.at(2);
MS_ASSERT(weight_recu != nullptr);
weight_r_ptr_ = reinterpret_cast<float *>(weight_recu->data_c());
int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error.";
return RET_ERROR;
}
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
const int state_bias_offset = 3 * gru_parm_->hidden_size_;
for (int i = 0; i < state_bias_offset; i++) {
bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
}
if (gru_parm_->bidirectional_) {
bias_data += 3 * gru_parm_->hidden_size_ * 2;
auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_;
for (int i = 0; i < state_bias_offset; i++) {
backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
}
}
return RET_OK;
}
int GruCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int GruCPUKernel::ReSize() {
FreeTmpBuffer();
auto ret = InitParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel InitParam error.";
return RET_ERROR;
}
ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error.";
FreeTmpBuffer();
return RET_ERROR;
}
ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel InitBuffer error.";
FreeTmpBuffer();
return RET_ERROR;
}
return RET_OK;
}
int GruCPUKernel::Run() {
auto input = in_tensors_.at(kInputIndex);
MS_ASSERT(input != nullptr);
auto hidden_state = in_tensors_.at(4);
MS_ASSERT(hidden_state != nullptr);
auto output = out_tensors_.at(0);
MS_ASSERT(output != nullptr);
auto input_ptr = reinterpret_cast<float *>(input->data_c());
MS_ASSERT(input_ptr);
auto output_ptr = reinterpret_cast<float *>(output->MutableData());
MS_ASSERT(output_ptr);
auto output_hidden_state = out_tensors_[1];
memcpy(output_hidden_state->MutableData(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
int check_seq_len = gru_parm_->seq_len_;
if (in_tensors_.size() == 6) {
auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) {
MS_LOG(ERROR) << "different batch seq_len is currently not supported";
return RET_ERROR;
}
check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
}
MS_ASSERT(weight_g_ptr_);
MS_ASSERT(weight_r_ptr_);
MS_ASSERT(bias_ptr_);
MS_ASSERT(gate_buffer_);
Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
reinterpret_cast<float *>(output_hidden_state->MutableData()), gate_buffer_, check_seq_len, gru_parm_);
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gru, LiteKernelCreator<GruCPUKernel>)
} // namespace mindspore::kernel
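One detail worth noting (my reading of InitWeightBias above): because the nnacl InitGruGate seeds the gate buffer with a single 3 * hidden_size bias before both matmul accumulations, the kernel pre-folds the two halves of each direction's 6 * hidden_size bias tensor element-wise, bias_ptr_[i] = bias_data[i] + bias_data[i + 3 * hidden_size], instead of adding the gate and recurrent biases in two passes.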

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp32/gru_fp32.h"
namespace mindspore::kernel {
class GruCPUKernel : public LiteKernel {
public:
GruCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_);
}
~GruCPUKernel() override { FreeTmpBuffer(); }
int Init() override;
int ReSize() override;
int Run() override;
private:
void FreeTmpBuffer();
int InitParam();
int InitBuffer();
int InitWeightBias();
float *gate_buffer_ = nullptr;
const float *weight_g_ptr_ = nullptr;
const float *weight_r_ptr_ = nullptr;
float *bias_ptr_ = nullptr;
GruParameter *gru_parm_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_

@@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/lstm_fp32.h"
#include <float.h>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@@ -32,6 +33,10 @@ void LstmCPUKernel::FreeTmpBuffer() {
free(gate_buffer_);
gate_buffer_ = nullptr;
}
if (state_buffer_ != nullptr) {
free(state_buffer_);
state_buffer_ = nullptr;
}
if (weight_i_ptr_ != nullptr) {
free(weight_i_ptr_);
weight_i_ptr_ = nullptr;
@@ -71,6 +76,14 @@ int LstmCPUKernel::InitBuffer() {
MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error.";
return RET_ERROR;
}
if (!(lstm_parm_->smooth_ >= -FLT_EPSILON && lstm_parm_->smooth_ <= FLT_EPSILON)) {
int buffer_size = 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float);
state_buffer_ = reinterpret_cast<float *>(malloc(buffer_size));
if (state_buffer_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error.";
return RET_ERROR;
}
}
return RET_OK;
}
@@ -173,7 +186,7 @@ int LstmCPUKernel::Run() {
MS_ASSERT(gate_buffer_);
Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
reinterpret_cast<float *>(output_hidden_state->MutableData()),
reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, lstm_parm_);
reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, state_buffer_, lstm_parm_);
return RET_OK;
}

@@ -44,6 +44,7 @@ class LstmCPUKernel : public LiteKernel {
int InitWeightBias();
float *gate_buffer_ = nullptr;
float *state_buffer_ = nullptr;
float *weight_i_ptr_ = nullptr;
float *weight_h_ptr_ = nullptr;
float *bias_ptr_ = nullptr;

@@ -187,6 +187,9 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc

@@ -0,0 +1 @@
decoder_step_201217.pb 5
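Judging from the tf-model loop added to Run_arm64 below, each line of this new config lists a model file followed by its input count, so decoder_step_201217.pb is benchmarked with 5 input files named <model>.ms_<i>.bin under /data/local/tmp/input_output/input/.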

@@ -925,6 +925,41 @@ function Run_arm64() {
fi
done < ${models_compatibility_config}
# Run tf converted models:
while read line; do
model_name=${line}
if [[ $model_name == \#* ]]; then
continue
fi
model_name=`echo ${line}|awk -F ' ' '{print $1}'`
input_num=`echo ${line}|awk -F ' ' '{print $2}'`
input_files=''
for i in $(seq 1 $input_num)
do
input_files=$input_files'/data/local/tmp/input_output/input/'$model_name'.ms_'$i'.bin,'
done
echo ${model_name} >> "${run_arm64_log_file}"
echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> "${run_arm64_log_file}"
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out' >> adb_run_cmd.txt
adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
if [ $? = 0 ]; then
run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
else
run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
fi
# run benchmark test without clib data
echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> "${run_arm64_log_file}"
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt
adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
if [ $? = 0 ]; then
run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
else
run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
fi
done < ${models_tf_config}
# Run tflite converted models:
while read line; do
model_name=${line}

@@ -46,6 +46,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/batchmatmul_fusion.cc
../optimizer/fusion/sigmoid_mul_fusion.cc
../optimizer/fusion/conv_conv_fusion.cc
../optimizer/fusion/tflite_lstm_cell_fusion.cc
../optimizer/fusion/tf_lstm_cell_fusion.cc
../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc

@@ -29,6 +29,9 @@
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include "tools/optimizer/fusion/conv_conv_fusion.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
#include "tools/optimizer/graph/identity_remove_pass.h"
@@ -114,6 +117,9 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>());
}
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
weight_format_hardcode_pass->SetFmkType(config->fmk);

@@ -572,7 +572,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) {
type = TypeIdToType(kObjectTypeTensorType);
}
anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector));
auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape_vector);
if (abstract == nullptr) {
MS_LOG(ERROR) << "create AbstractTensor failed";
return RET_ERROR;
}
anf_node->set_abstract(abstract);
anf_node_map->insert(std::pair(op.name(), anf_node));
} else {
AbstractBasePtrList abstractList;
@@ -589,6 +594,12 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue};
CNodePtr getItemCNode = anf_graph->NewCNode(inputs);
std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
if (abstract == nullptr) {
MS_LOG(ERROR) << "create AbstractTensor failed";
return RET_ERROR;
}
getItemCNode->set_abstract(abstract);
getItemCNode->set_fullname_with_scope(output_item_name);
anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode));
}

@@ -63,7 +63,11 @@ STATUS TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op,
}
*output_size = 1;
return AddOpInput(tf_op, 0, inputs);
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
return AddOpInput(tf_op, 1, inputs);
}
TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser());
} // namespace lite

@@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_select_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFSelectParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF SelectParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::SwitchT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Switch;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return RET_OK;
}
TFNodeRegistrar g_tfSelectParser("Select", new TFSelectParser());
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff.