|
|
|
@ -18,7 +18,7 @@
|
|
|
|
|
#include <string.h>
|
|
|
|
|
#include "nnacl/arithmetic_common.h"
|
|
|
|
|
|
|
|
|
|
void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para) {
|
|
|
|
|
void ReverseSequence(float *input0, void *input1, float *output, ReverseSequenceParameter *para) {
|
|
|
|
|
(void)memcpy(output, input0, para->total_data_size_);
|
|
|
|
|
ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_);
|
|
|
|
|
ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_);
|
|
|
|
@ -28,8 +28,9 @@ void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceP
|
|
|
|
|
for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) {
|
|
|
|
|
float *in_batch = in + batch * para->input_stride_[para->batch_axis_];
|
|
|
|
|
float *out_batch = out + batch * para->output_stride_[para->batch_axis_];
|
|
|
|
|
for (int n = 0; n < input1[batch]; ++n) {
|
|
|
|
|
float *in_seq = in_batch + (input1[batch] - 1 - n) * para->input_stride_[para->seq_axis_];
|
|
|
|
|
int32_t seq_length = para->is_seq_length_int32_ ? *((int32_t *)input1 + batch) : *((int64_t *)input1 + batch);
|
|
|
|
|
for (int n = 0; n < seq_length; ++n) {
|
|
|
|
|
float *in_seq = in_batch + (seq_length - 1 - n) * para->input_stride_[para->seq_axis_];
|
|
|
|
|
float *out_seq = out_batch + n * para->output_stride_[para->seq_axis_];
|
|
|
|
|
for (int j = 0; j < para->inner_count_; ++j) {
|
|
|
|
|
(void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_);
|
|
|
|
|