|
|
|
@ -241,8 +241,6 @@ class StridedSliceKernel : public framework::OpKernel<T> {
|
|
|
|
|
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::Tensor tmp;
|
|
|
|
|
|
|
|
|
|
auto out_dims_origin = out_dims;
|
|
|
|
|
if (decrease_axis.size() > 0) {
|
|
|
|
|
std::vector<int> new_out_shape;
|
|
|
|
@ -263,21 +261,34 @@ class StridedSliceKernel : public framework::OpKernel<T> {
|
|
|
|
|
out_dims_origin = framework::make_ddim(new_out_shape);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tmp.mutable_data<T>(out_dims, context.GetPlace());
|
|
|
|
|
bool need_reverse = false;
|
|
|
|
|
for (size_t axis = 0; axis < axes.size(); axis++) {
|
|
|
|
|
if (reverse_vector[axis] == 1) {
|
|
|
|
|
need_reverse = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto in_t =
|
|
|
|
|
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
|
|
|
|
|
*in);
|
|
|
|
|
auto tmp_t =
|
|
|
|
|
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
|
|
|
|
|
tmp);
|
|
|
|
|
auto out_t =
|
|
|
|
|
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
|
|
|
|
|
*out, out_dims);
|
|
|
|
|
tmp_t.device(place) =
|
|
|
|
|
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
|
|
|
|
|
out_t.device(place) = tmp_t.reverse(reverse_axis);
|
|
|
|
|
if (need_reverse) {
|
|
|
|
|
framework::Tensor tmp;
|
|
|
|
|
tmp.mutable_data<T>(out_dims, context.GetPlace());
|
|
|
|
|
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
|
|
|
|
|
Eigen::DenseIndex>::From(tmp);
|
|
|
|
|
tmp_t.device(place) =
|
|
|
|
|
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
|
|
|
|
|
out_t.device(place) = tmp_t.reverse(reverse_axis);
|
|
|
|
|
} else {
|
|
|
|
|
out_t.device(place) =
|
|
|
|
|
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (decrease_axis.size() > 0) {
|
|
|
|
|
out->Resize(out_dims_origin);
|
|
|
|
@ -388,22 +399,33 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::Tensor reverse_input;
|
|
|
|
|
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
|
|
|
|
|
|
|
|
|
|
bool need_reverse = false;
|
|
|
|
|
for (size_t axis = 0; axis < axes.size(); axis++) {
|
|
|
|
|
if (reverse_vector[axis] == 1) {
|
|
|
|
|
need_reverse = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto in_t =
|
|
|
|
|
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
|
|
|
|
|
*d_input);
|
|
|
|
|
auto reverse_in_t =
|
|
|
|
|
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
|
|
|
|
|
reverse_input);
|
|
|
|
|
auto out_t =
|
|
|
|
|
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
|
|
|
|
|
*d_out, out_dims);
|
|
|
|
|
|
|
|
|
|
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
|
|
|
|
|
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
|
|
|
|
|
.device(place) = reverse_in_t;
|
|
|
|
|
if (need_reverse) {
|
|
|
|
|
framework::Tensor reverse_input;
|
|
|
|
|
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
|
|
|
|
|
auto reverse_in_t =
|
|
|
|
|
framework::EigenTensor<T, D, Eigen::RowMajor,
|
|
|
|
|
Eigen::DenseIndex>::From(reverse_input);
|
|
|
|
|
|
|
|
|
|
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
|
|
|
|
|
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
|
|
|
|
|
.device(place) = reverse_in_t;
|
|
|
|
|
} else {
|
|
|
|
|
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
|
|
|
|
|
.device(place) = in_t;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|