|
|
|
@ -83,7 +83,8 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
lod[0][i] + offset_data[i] + length_data[i],
|
|
|
|
|
lod[0][i + 1],
|
|
|
|
|
"The target tensor's length overflow")}
|
|
|
|
|
"The target tensor's length overflow")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
|
|
|
|
@ -140,27 +141,29 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto lod = in->lod();
|
|
|
|
|
auto out_lod = out_grad->lod();
|
|
|
|
|
|
|
|
|
|
x_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
math::SetConstant<Place, T> set_zero;
|
|
|
|
|
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
|
|
|
|
|
if (x_grad) {
|
|
|
|
|
x_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
math::SetConstant<Place, T> set_zero;
|
|
|
|
|
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
auto out_grad_stride = framework::stride(out_grad->dims());
|
|
|
|
|
auto out_grad_stride = framework::stride(out_grad->dims());
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
|
|
|
|
|
Tensor out_grad_t =
|
|
|
|
|
out_grad->Slice(static_cast<int>(out_lod[0][i]),
|
|
|
|
|
static_cast<int>(out_lod[0][i + 1]));
|
|
|
|
|
auto out_grad_stride = framework::stride(out_grad_t.dims());
|
|
|
|
|
for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
|
|
|
|
|
Tensor out_grad_t =
|
|
|
|
|
out_grad->Slice(static_cast<int>(out_lod[0][i]),
|
|
|
|
|
static_cast<int>(out_lod[0][i + 1]));
|
|
|
|
|
auto out_grad_stride = framework::stride(out_grad_t.dims());
|
|
|
|
|
|
|
|
|
|
auto x_grad_stride = framework::stride(x_grad->dims());
|
|
|
|
|
auto x_grad_stride = framework::stride(x_grad->dims());
|
|
|
|
|
|
|
|
|
|
Tensor x_grad_t = x_grad->Slice(
|
|
|
|
|
static_cast<int>(lod[0][i] + offset_data[i]),
|
|
|
|
|
static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
|
|
|
|
|
Tensor x_grad_t = x_grad->Slice(
|
|
|
|
|
static_cast<int>(lod[0][i] + offset_data[i]),
|
|
|
|
|
static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
|
|
|
|
|
|
|
|
|
|
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
|
|
|
|
|
out_grad_stride, out_grad_t.dims(), x_grad_stride,
|
|
|
|
|
x_grad_t.data<T>());
|
|
|
|
|
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
|
|
|
|
|
out_grad_stride, out_grad_t.dims(), x_grad_stride,
|
|
|
|
|
x_grad_t.data<T>());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|