add support tensor and tensorlist for strided_slice OP (#19929)

* add support tensor and tensorlist for strided_slice OP test=develop

* fix the commnet test=develop

* fix test=develop

* fix the bug test=develop

* delete log test=develop

* fix API.spec test=develop

* fix test=develop
expand_as_op_1
wangchaochaohu 5 years ago committed by GitHub
parent fe218df326
commit 382d099dcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -251,7 +251,7 @@ paddle.fluid.layers.sampling_id (ArgSpec(args=['x', 'min', 'max', 'seed', 'dtype
paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', 'b24d0b21361c4bb8ef2cec8c26fb12b2')) paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', 'b24d0b21361c4bb8ef2cec8c26fb12b2'))
paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310')) paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310'))
paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '315b4870f294e33a27ecbdf440bed3ff')) paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '315b4870f294e33a27ecbdf440bed3ff'))
paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', 'a2e5296d34c081f2a67890aaa5f02238')) paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', '340d8d656272ea396b441aab848429a2'))
paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b')) paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b'))
paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3')) paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3'))
paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe')) paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe'))

@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h" #include "paddle/fluid/operators/strided_slice_op.h"
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/operators/slice_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -39,56 +41,56 @@ class StridedSliceOp : public framework::OperatorWithKernel {
auto ends = ctx->Attrs().Get<std::vector<int>>("ends"); auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides"); auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto axes = ctx->Attrs().Get<std::vector<int>>("axes"); auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
PADDLE_ENFORCE_EQ(starts.size(), ends.size(), auto starts_size = starts.size();
"starts and ends dim size must to be same"); auto ends_size = ends.size();
PADDLE_ENFORCE_EQ(ends.size(), strides.size(), auto strides_size = strides.size();
"ends and strides dim size must to be same");
PADDLE_ENFORCE_EQ(ends.size(), axes.size(),
"axes, end and start dim size must to be same");
// we need to analysis strided slice op is valid for if (ctx->HasInputs("StartsTensorList")) {
// the parameter that we get from python front auto StartsTensorList = ctx->Inputs("StartsTensorList");
int stride_index, start_index, end_index; PADDLE_ENFORCE_GT(StartsTensorList.size(), 0,
std::vector<int> out_dims_vector(in_dims.size()); "StartsTensorList size can't be zero");
for (int i = 0; i < in_dims.size(); i++) { starts_size = StartsTensorList.size();
out_dims_vector[i] = in_dims[i];
} }
for (size_t i = 0; i < starts.size(); i++) { if (ctx->HasInputs("EndsTensorList")) {
PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero"); auto EndsTensorList = ctx->Inputs("EndsTensorList");
int axes_index = axes[i]; PADDLE_ENFORCE_GT(EndsTensorList.size(), 0,
start_index = starts[i]; "EndsTensorList size can't be zero");
end_index = ends[i]; ends_size = EndsTensorList.size();
stride_index = strides[i]; }
int axis_size = in_dims[axes_index]; if (ctx->HasInputs("StridesTensorList")) {
if (axis_size < 0) { auto StridesTensorList = ctx->Inputs("StridesTensorList");
continue; PADDLE_ENFORCE_GT(StridesTensorList.size(), 0,
"StridesTensorList size can't be zero");
strides_size = StridesTensorList.size();
} }
if (start_index < 0) { auto tensor_input = false;
start_index = start_index + axis_size; if (ctx->HasInput("EndsTensor") || ctx->HasInput("StartsTensor") ||
ctx->HasInput("StridesTensor")) {
tensor_input = true;
} }
if (end_index < 0) { if (ctx->HasInput("EndsTensor") == false) {
end_index = end_index + axis_size; PADDLE_ENFORCE_EQ(ends_size, axes.size(),
"The size of ends must be equal to the size of axes.");
} }
if (ctx->HasInput("StartsTensor") == false) {
if (stride_index < 0) { PADDLE_ENFORCE_EQ(
start_index = start_index + 1; starts_size, axes.size(),
end_index = end_index + 1; "The size of starts must be equal to the size of axes.");
} }
if (ctx->HasInput("StridesTensor") == false) {
bool zero_dim_condition = PADDLE_ENFORCE_EQ(
((stride_index < 0 && (start_index <= end_index)) || strides_size, axes.size(),
(stride_index > 0 && (start_index >= end_index))); "The size of strides must be equal to the size of axes.");
PADDLE_ENFORCE_EQ(zero_dim_condition, false, }
"starts and end must meet requirement in different " // we need to analysis strided slice op is valid for
"stride conditiont"); // the parameter that we get from python front
int left = std::max(0, std::min(start_index, end_index)); std::vector<int> out_dims_vector(in_dims.size(), -1);
int right = std::min(axis_size, std::max(start_index, end_index)); if (!tensor_input) {
int step = std::abs(stride_index); StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
auto out_dims_index = (std::abs(right - left) + step - 1) / step; out_dims_vector.data(), axes.size(), true);
out_dims_vector[axes_index] = out_dims_index;
} }
framework::DDim out_dims(framework::make_ddim(out_dims_vector)); framework::DDim out_dims(framework::make_ddim(out_dims_vector));
@ -102,22 +104,79 @@ class StridedSliceOp : public framework::OperatorWithKernel {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.Input<Tensor>("Input")->place()); ctx.Input<Tensor>("Input")->place());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
var_name == "StridesTensor") {
return expected_kernel_type;
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
var_name == "StridesTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}; };
class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker { class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Input", "Tensor of data to extract slices from."); AddInput("Input", "Tensor of data to extract slices from.");
AddOutput("Out", "Sliced data tensor."); AddOutput("Out", "Strided Sliced data tensor.");
AddInput("StartsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of StartsTensor, StartsTensorList "
"and attr(starts).")
.AsDispensable();
AddInput("EndsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of EndsTensor, EndsTensorList and "
"attr(ends).")
.AsDispensable();
AddInput(
"StridesTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of StridesTensor, StridesTensorList and "
"attr(ends).")
.AsDispensable();
AddInput(
"StartsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(starts).")
.AsDuplicable()
.AsDispensable();
AddInput(
"EndsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(ends).")
.AsDuplicable()
.AsDispensable();
AddInput(
"StridesTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(strides).")
.AsDuplicable()
.AsDispensable();
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"axes", "(list<int> Axes stride from the start to the end)"); "axes", "(list<int>) Axes that `starts` and `ends` apply to.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"starts", "(list<int>) start that the tensor slice start."); "starts", "(list<int>) Start indices for the strided slice start.")
.SetDefault({});
AddAttr<std::vector<int>>("ends", AddAttr<std::vector<int>>("ends",
"(list<int>) end that the tensor slice end"); "(list<int>) End indices the tensor slice end")
.SetDefault({});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "(list<int> stride stride from the start to the end)"); "strides", "(list<int> Stride step from the start to the end)")
.SetDefault({});
AddAttr<std::vector<int>>(
"infer_flags", "(list<int>) Flags of inferring dims in attributes.")
.SetDefault({});
AddComment(R"DOC( AddComment(R"DOC(
Strided Slice Operator. Strided Slice Operator.
Instead of calling this op directly most users will want to use the Instead of calling this op directly most users will want to use the
@ -150,6 +209,18 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(), ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace()); ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor") {
return expected_kernel_type;
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}; };
class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker { class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker {
@ -161,6 +232,12 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker {
auto *bind = new framework::OpDesc(); auto *bind = new framework::OpDesc();
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetInput("Input", Input("Input")); bind->SetInput("Input", Input("Input"));
bind->SetInput("StartsTensor", Input("StartsTensor"));
bind->SetInput("EndsTensor", Input("EndsTensor"));
bind->SetInput("StridesTensor", Input("StridesTensor"));
bind->SetInput("StartsTensorList", Input("StartsTensorList"));
bind->SetInput("EndsTensorList", Input("EndsTensorList"));
bind->SetInput("StridesTensorList", Input("StridesTensorList"));
bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input")); bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
bind->SetAttrMap(Attrs()); bind->SetAttrMap(Attrs());
bind->SetType("strided_slice_grad"); bind->SetType("strided_slice_grad");

@ -19,9 +19,62 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static void StridedSliceOutDims(
const std::vector<int>& starts, const std::vector<int>& ends,
const std::vector<int>& strides, const std::vector<int>& axes,
const std::vector<int>& infer_flags, const framework::DDim in_dims,
int* out_dims_vector, const size_t size, bool infer_shape) {
for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i];
}
int stride_index, start_index, end_index;
for (size_t i = 0; i < size; i++) {
int axes_index = axes[i];
if (infer_shape && infer_flags[i] == -1) {
out_dims_vector[axes_index] = -1;
continue;
}
PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero");
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
int axis_size = in_dims[axes_index];
if (axis_size < 0) {
continue;
}
if (start_index < 0) {
start_index = start_index + axis_size;
}
if (end_index < 0) {
end_index = end_index + axis_size;
}
if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
}
bool zero_dim_condition =
((stride_index < 0 && (start_index <= end_index)) ||
(stride_index > 0 && (start_index >= end_index)));
PADDLE_ENFORCE_EQ(zero_dim_condition, false,
"starts and end must meet requirement in different "
"stride conditiont");
int left = std::max(0, std::min(start_index, end_index));
int right = std::min(axis_size, std::max(start_index, end_index));
int step = std::abs(stride_index);
auto out_dims_index = (std::abs(right - left) + step - 1) / step;
out_dims_vector[axes_index] = out_dims_index;
}
}
static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
int* reverse_axis, const framework::DDim dims, int* reverse_axis, const framework::DDim dims,
const size_t size) { const size_t size) {
@ -91,19 +144,52 @@ class StridedSliceKernel : public framework::OpKernel<T> {
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
auto in = context.Input<framework::Tensor>("Input"); auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out"); auto out = context.Output<framework::Tensor>("Out");
auto out_dims = out->dims();
auto in_dims = in->dims(); auto in_dims = in->dims();
auto starts = context.Attr<std::vector<int>>("starts"); auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends"); auto ends = context.Attr<std::vector<int>>("ends");
auto strides = context.Attr<std::vector<int>>("strides"); auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes"); auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>(); auto reverse_axis = Eigen::array<bool, D>();
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = get_new_data_from_tensorlist(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = get_new_data_from_tensor(strides_tensor);
}
std::vector<int> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
out_dims_vector.data(), axes.size(), false);
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts.size(), 0); std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), in_dims, starts.size()); reverse_vector.data(), in_dims, starts.size());
@ -112,6 +198,7 @@ class StridedSliceKernel : public framework::OpKernel<T> {
starts_indices[axis] = 0; starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis]; ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1; strides_indices[axis] = 1;
reverse_axis[axis] = false;
} }
for (size_t axis = 0; axis < axes.size(); axis++) { for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis]; int axis_index = axes[axis];
@ -124,6 +211,7 @@ class StridedSliceKernel : public framework::OpKernel<T> {
framework::Tensor tmp; framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace()); tmp.mutable_data<T>(out_dims, context.GetPlace());
out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto in_t = auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From( framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
@ -189,6 +277,34 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
auto strides = context.Attr<std::vector<int>>("strides"); auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes"); auto axes = context.Attr<std::vector<int>>("axes");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = get_new_data_from_tensorlist(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = get_new_data_from_tensor(strides_tensor);
}
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();

@ -11398,55 +11398,146 @@ def strided_slice(input, axes, starts, ends, strides):
ends = [2, 3] ends = [2, 3]
strides=[1, 1] strides=[1, 1]
Then: Then:
result = [ [5, 6, 7] ] result = [ [5, 6, 7], ]
Case2: Case2:
Given: Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
axes = [0, 1] axes = [0, 1]
starts = [0, -1] starts = [0, 1]
ends = [-1, 0] ends = [-1, 1000]
strides = [1, -1] strides = [1, 3]
Then: Then:
result = [ [4, 3, 2] ] result = [ [2], ]
Atrgs: Args:
input (Varibale): the input variable. input (Variable): ${input_comment}.
axes(List):axis we need to slice axes (List): ${axes_comment}
starts (List): the start index in axis starts (List|Variable): ${starts_comment}
ends (List): the end index in axis ends (List|Variable): ${ends_comment}
strides (List): the stride length when we do slice operation
Returns Returns:
out(Variable): the result by strided_slice Op out (Variable): ${out_comment}
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
starts = [1, 0, 2]
ends = [3, 3, 4]
axes = [0, 1, 2]
strides= [1, 1, 1]
input = fluid.layers.data( input = fluid.layers.data(
name="input", shape=[3, 4, 5, 6], dtype='float32') name="input", shape=[3, 4, 5, 6], dtype='float32')
out = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides) # example 1:
# attr starts is a list which doesn't contain tensor Variable.
axes = [0, 1, 2]
starts = [-3, 0, 2]
ends = [3, 2, 4]
strides=[1, 1, 1]
sliced_1 = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides)
# example 2:
# attr starts is a list which contain tensor Variable.
minus_3 = fluid.layers.fill_constant([1], "int32", -3)
sliced_2 = fluid.layers.strided_slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides)
""" """
if not isinstance(starts, (list, tuple, Variable)):
raise ValueError(
"Input starts must be an Variable, python list or tuple.")
if not isinstance(ends, (list, tuple, Variable)):
raise ValueError(
"Input ends must be an Variable, python list or tuple.")
if not isinstance(strides, (list, tuple, Variable)):
raise ValueError(
"Input strides must be an Variable, python list or tuple.")
helper = LayerHelper('strided_slice', **locals()) helper = LayerHelper('strided_slice', **locals())
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op( def contain_var(one_list):
type='strided_slice', for ele in one_list:
inputs={'Input': input}, if isinstance(ele, Variable):
outputs={'Out': out}, return True
return False
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': input}
attrs = {'axes': axes}
infer_flags = list(1 for i in range(len(axes)))
if in_dygraph_mode():
inputs = {'Input': input}
attrs = { attrs = {
'axes': axes, 'axes': axes,
'starts': starts, 'starts': starts,
'ends': ends, 'ends': ends,
'strides': strides 'strides': strides,
}) 'infer_flags': infer_flags
}
else:
# starts
if isinstance(starts, Variable):
starts.stop_gradient = True
inputs['StartsTensor'] = starts
elif isinstance(starts, (list, tuple)):
attrs['starts'] = []
if not contain_var(starts):
attrs['starts'] = starts
else:
inputs['StartsTensorList'] = get_new_list_tensor(starts)
for i, dim in enumerate(starts):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
# ends
if isinstance(ends, Variable):
ends.stop_gradient = True
inputs['EndsTensor'] = ends
elif isinstance(ends, (list, tuple)):
attrs['ends'] = []
if not contain_var(ends):
attrs['ends'] = ends
else:
inputs['EndsTensorList'] = get_new_list_tensor(ends)
for i, dim in enumerate(ends):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
# strides
if isinstance(strides, Variable):
strides.stop_gradient = True
inputs['StridesTensor'] = strides
elif isinstance(strides, (list, tuple)):
attrs['strides'] = []
if not contain_var(strides):
attrs['strides'] = strides
else:
inputs['StridesTensorList'] = get_new_list_tensor(strides)
for i, dim in enumerate(strides):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
attrs['infer_flags'] = infer_flags
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op(
type='strided_slice', inputs=inputs, attrs=attrs, outputs={'Out': out})
return out return out

Loading…
Cancel
Save