!5994 [MSLITE][Develop] reverse_seqence seq_lengths support int64

Merge pull request !5994 from sunsuodong/fix_reverse_seqence
pull/5994/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 7913f8361c

@ -18,7 +18,7 @@
#include <string.h>
#include "nnacl/arithmetic_common.h"
void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para) {
void ReverseSequence(float *input0, void *input1, float *output, ReverseSequenceParameter *para) {
(void)memcpy(output, input0, para->total_data_size_);
ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_);
ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_);
@ -28,8 +28,9 @@ void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceP
for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) {
float *in_batch = in + batch * para->input_stride_[para->batch_axis_];
float *out_batch = out + batch * para->output_stride_[para->batch_axis_];
for (int n = 0; n < input1[batch]; ++n) {
float *in_seq = in_batch + (input1[batch] - 1 - n) * para->input_stride_[para->seq_axis_];
int32_t seq_length = para->is_seq_length_int32_ ? *((int32_t *)input1 + batch) : *((int64_t *)input1 + batch);
for (int n = 0; n < seq_length; ++n) {
float *in_seq = in_batch + (seq_length - 1 - n) * para->input_stride_[para->seq_axis_];
float *out_seq = out_batch + n * para->output_stride_[para->seq_axis_];
for (int j = 0; j < para->inner_count_; ++j) {
(void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_);

@ -34,12 +34,13 @@ typedef struct ReverseSequenceParameter {
int inner_stride_;
int copy_byte_size_;
int total_data_size_;
bool is_seq_length_int32_;
} ReverseSequenceParameter;
#ifdef __cplusplus
extern "C" {
#endif
void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para);
void ReverseSequence(float *input0, void *input1, float *output, ReverseSequenceParameter *para);
#ifdef __cplusplus
}
#endif

@ -93,9 +93,11 @@ int ReverseSequenceCPUKernel::Run() {
return ret;
}
float *input0 = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
int *input1 = reinterpret_cast<int *>(in_tensors_.at(1)->MutableData());
void *input1 = in_tensors_.at(1)->MutableData();
float *output = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
ReverseSequence(input0, input1, output, reinterpret_cast<ReverseSequenceParameter *>(op_parameter_));
ReverseSequenceParameter *param = reinterpret_cast<ReverseSequenceParameter *>(op_parameter_);
param->is_seq_length_int32_ = in_tensors_.at(1)->data_type() == kNumberTypeInt32;
ReverseSequence(input0, input1, output, param);
return RET_OK;
}

Loading…
Cancel
Save