|
|
|
@ -54,10 +54,10 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
n, static_cast<size_t>(length->dims()[0]),
|
|
|
|
|
"The size of input-sequence and length-array should be the same")
|
|
|
|
|
"The size of input-sequence and length-array should be the same");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
n, static_cast<size_t>(offset->dims()[0]),
|
|
|
|
|
"The size of input-sequence and offset-array should be the same")
|
|
|
|
|
"The size of input-sequence and offset-array should be the same");
|
|
|
|
|
|
|
|
|
|
const int64_t* offset_data = offset->data<int64_t>();
|
|
|
|
|
const int64_t* length_data = length->data<int64_t>();
|
|
|
|
@ -78,11 +78,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < n; ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(0, offset_data[i],
|
|
|
|
|
"The offset[%d] must greater than zero.", i)
|
|
|
|
|
"The offset[%d] must greater than zero.", i);
|
|
|
|
|
PADDLE_ENFORCE_LT(0, length_data[i],
|
|
|
|
|
"The length[%d] must greater than zero.", i)
|
|
|
|
|
"The length[%d] must greater than zero.", i);
|
|
|
|
|
PADDLE_ENFORCE_LT(lod[0][i] + offset_data[i] + length_data[i],
|
|
|
|
|
lod[0][i + 1], "The target tensor's length overflow.")
|
|
|
|
|
lod[0][i + 1], "The target tensor's length overflow.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|