|
|
|
@ -29,14 +29,16 @@ class StridedSliceOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
|
|
|
|
|
"Input (Input) of slice op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output (Out) of slice op should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "StridedSlice");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "StridedSlice");
|
|
|
|
|
|
|
|
|
|
auto in_dims = ctx->GetInputDim("Input");
|
|
|
|
|
PADDLE_ENFORCE_LT(in_dims.size(), 7,
|
|
|
|
|
"The rank of input should be less than 7.");
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
in_dims.size(), 7,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimension of StridedSlice operator's input should be less "
|
|
|
|
|
"than 7, but received dimension is %d.",
|
|
|
|
|
in_dims.size()));
|
|
|
|
|
auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
|
|
|
|
|
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
|
|
|
|
|
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
|
|
|
|
@ -50,20 +52,26 @@ class StridedSliceOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInputs("StartsTensorList")) {
|
|
|
|
|
auto StartsTensorList = ctx->Inputs("StartsTensorList");
|
|
|
|
|
PADDLE_ENFORCE_GT(StartsTensorList.size(), 0,
|
|
|
|
|
"StartsTensorList size can't be zero");
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
StartsTensorList.size(), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"StridedSlice operator's StartsTensorList is empty."));
|
|
|
|
|
starts_size = StartsTensorList.size();
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasInputs("EndsTensorList")) {
|
|
|
|
|
auto EndsTensorList = ctx->Inputs("EndsTensorList");
|
|
|
|
|
PADDLE_ENFORCE_GT(EndsTensorList.size(), 0,
|
|
|
|
|
"EndsTensorList size can't be zero");
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
EndsTensorList.size(), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"StridedSlice operator's EndsTensorList is empty."));
|
|
|
|
|
ends_size = EndsTensorList.size();
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasInputs("StridesTensorList")) {
|
|
|
|
|
auto StridesTensorList = ctx->Inputs("StridesTensorList");
|
|
|
|
|
PADDLE_ENFORCE_GT(StridesTensorList.size(), 0,
|
|
|
|
|
"StridesTensorList size can't be zero");
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
StridesTensorList.size(), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"StridedSlice operator's StridesTensorList is empty."));
|
|
|
|
|
strides_size = StridesTensorList.size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -73,18 +81,31 @@ class StridedSliceOp : public framework::OperatorWithKernel {
|
|
|
|
|
tensor_input = true;
|
|
|
|
|
}
|
|
|
|
|
if (!ctx->HasInput("EndsTensor")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ends_size, axes.size(),
|
|
|
|
|
"The size of ends must be equal to the size of axes.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ends_size, axes.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The size of ends attribute in StridedSlice operator is not "
|
|
|
|
|
"equal to the size of axes attribute. The ends attribute's size "
|
|
|
|
|
"is %d, axes attribute's size is %d.",
|
|
|
|
|
ends_size, axes.size()));
|
|
|
|
|
}
|
|
|
|
|
if (!ctx->HasInput("StartsTensor")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
starts_size, axes.size(),
|
|
|
|
|
"The size of starts must be equal to the size of axes.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The size of starts attribute in StridedSlice operator is not "
|
|
|
|
|
"equal to the size of axes attribute. The starts attribute's "
|
|
|
|
|
"size is %d, axes attribute's size is %d.",
|
|
|
|
|
starts_size, axes.size()));
|
|
|
|
|
}
|
|
|
|
|
if (!ctx->HasInput("StridesTensor")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
strides_size, axes.size(),
|
|
|
|
|
"The size of strides must be equal to the size of axes.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The size of strides attribute in StridedSlice operator is not "
|
|
|
|
|
"equal to the size of axes attribute. The strides attribute's "
|
|
|
|
|
"size is %d, axes attribute's size is %d.",
|
|
|
|
|
strides_size, axes.size()));
|
|
|
|
|
}
|
|
|
|
|
// we need to analysis strided slice op is valid for
|
|
|
|
|
// the parameter that we get from python front
|
|
|
|
@ -101,7 +122,10 @@ class StridedSliceOp : public framework::OperatorWithKernel {
|
|
|
|
|
for (size_t i = 0; i < decrease_axis.size(); ++i) {
|
|
|
|
|
if (ctx->IsRuntime() && infer_flags[i] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
|
|
|
|
|
"decrease dim should be 1");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"the size of decrease dimension should be 1, "
|
|
|
|
|
"but received %d.",
|
|
|
|
|
out_dims[decrease_axis[i]]));
|
|
|
|
|
}
|
|
|
|
|
out_dims[decrease_axis[i]] = 0;
|
|
|
|
|
}
|
|
|
|
@ -219,9 +243,11 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
|
|
|
|
|
"StridedSliceGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
"Out@GRAD", "StridedSliceGrad");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("Input");
|
|
|
|
|
auto x_grad_name = framework::GradVarName("Input");
|
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
|