Optimize the computing kernel of sequence_reverse operator (#17349)

* Optimize the computing kernel of sequence_reverse operator.

test=develop

* Clean code

test=develop

* Fix for cpplint syntax checking.

test=develop

* Fix the compile warning issue.

test=develop
revert-17080-prepare_data
Yihua Xu 6 years ago committed by Tao Luo
parent dcda20233c
commit 218d8d8f73

@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/platform/for_range.h"
@@ -109,7 +110,6 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(x.lod().size(), 1,
"SequenceReverse Op only support one level lod.");
auto &dev_ctx = ctx.template device_context<DeviceContext>();
const size_t *lod;
size_t lod_count = x.lod()[0].size();
@@ -131,10 +131,24 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NE(x_data, y_data,
"SequenceReverse Op does not support in-place operation");
SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count,
row_numel);
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
for_range(functor);
if (platform::is_cpu_place(ctx.GetPlace())) {
for (size_t idx = 0; idx < lod_count - 1; idx++) {
auto start_pos = lod[idx];
auto end_pos = lod[idx + 1];
for (auto pos = start_pos; pos < end_pos; pos++) {
auto cur_pos = end_pos - pos - 1 + start_pos;
std::memcpy(y_data + pos * row_numel, x_data + cur_pos * row_numel,
row_numel * sizeof(T));
}
}
} else {
auto &dev_ctx = ctx.template device_context<DeviceContext>();
SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count,
row_numel);
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
for_range(functor);
}
}
};

Loading…
Cancel
Save