@@ -29,10 +29,12 @@ class SliceOp : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
-                      "Input (Input) of slice op should not be null.");
+                      platform::errors::InvalidArgument(
+                          "Input (Input) of slice op should not be null."));
 
     PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      "Output (Out) of slice op should not be null.");
+                      platform::errors::InvalidArgument(
+                          "Output (Out) of slice op should not be null."));
     auto x_var_type = ctx->GetInputsVarType("Input")[0];
     auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
     if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
@@ -57,7 +59,8 @@ class SliceOp : public framework::OperatorWithKernel {
     }
     auto in_dims = ctx->GetInputDim("Input");
     PADDLE_ENFORCE_LT(in_dims.size(), 7,
-                      "The rank of input should be less than 7.");
+                      platform::errors::InvalidArgument(
+                          "The rank of input should be less than 7."));
     framework::DDim out_dims(in_dims);
 
     auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
@@ -76,31 +79,37 @@ class SliceOp : public framework::OperatorWithKernel {
     if (ctx->HasInputs("StartsTensorList")) {
       auto StartsTensorList = ctx->Inputs("StartsTensorList");
       PADDLE_ENFORCE_GT(StartsTensorList.size(), 0,
-                        "StartsTensorList size can't be zero");
+                        platform::errors::InvalidArgument(
+                            "StartsTensorList size can't be zero"));
       starts_size = StartsTensorList.size();
     }
     if (ctx->HasInputs("EndsTensorList")) {
       auto EndsTensorList = ctx->Inputs("EndsTensorList");
       PADDLE_ENFORCE_GT(EndsTensorList.size(), 0,
-                        "EndsTensorList size can't be zero");
+                        platform::errors::InvalidArgument(
+                            "EndsTensorList size can't be zero"));
       ends_size = EndsTensorList.size();
     }
 
     if (ctx->HasInput("StartsTensor") == false) {
       PADDLE_ENFORCE_EQ(
           starts_size, axes.size(),
-          "The size of starts must be equal to the size of axes.");
+          platform::errors::InvalidArgument(
+              "The size of starts must be equal to the size of axes."));
     }
     if (ctx->HasInput("EndsTensor") == false) {
-      PADDLE_ENFORCE_EQ(ends_size, axes.size(),
-                        "The size of ends must be equal to the size of axes.");
+      PADDLE_ENFORCE_EQ(
+          ends_size, axes.size(),
+          platform::errors::InvalidArgument(
+              "The size of ends must be equal to the size of axes."));
     }
 
     int dim_value, start, end;
     for (size_t i = 0; i < axes.size(); ++i) {
       PADDLE_ENFORCE_LT(static_cast<int>(axes[i]), in_dims.size(),
-                        "The index of dimension in axes must be less "
-                        "than the size of input shape.");
+                        platform::errors::InvalidArgument(
+                            "The index of dimension in axes must be less "
+                            "than the size of input shape."));
       if (infer_flags[i] == -1) {
         out_dims[axes[i]] = -1;
       } else {
@@ -112,7 +121,8 @@ class SliceOp : public framework::OperatorWithKernel {
           start = std::max(start, 0);
           end = std::max(end, 0);
           end = std::min(end, dim_value);
-          PADDLE_ENFORCE_GT(end, start, "end should greater than start");
+          PADDLE_ENFORCE_GT(end, start, platform::errors::InvalidArgument(
+                                            "end should greater than start"));
           out_dims[axes[i]] = end - start;
         }
       }
@@ -122,8 +132,9 @@ class SliceOp : public framework::OperatorWithKernel {
       std::vector<int> new_out_shape;
       for (size_t i = 0; i < decrease_axis.size(); ++i) {
         if (ctx->IsRuntime() && infer_flags[i] != -1) {
-          PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
-                            "decrease dim should be 1");
+          PADDLE_ENFORCE_EQ(
+              out_dims[decrease_axis[i]], 1,
+              platform::errors::InvalidArgument("decrease dim should be 1"));
         }
         out_dims[decrease_axis[i]] = 0;
       }
@@ -284,9 +295,12 @@ class SliceOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Input"), true,
+        platform::errors::InvalidArgument("Input should not be null"));
     PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
-                      "Input(Out@GRAD) should not be null");
+                      platform::errors::InvalidArgument(
+                          "Input(Out@GRAD) should not be null"));
     auto x_var_type = ctx->GetInputsVarType("Input")[0];
     if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
       // If the var type of input is LOD_TENSOR_ARRAY,
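
Note on the pattern: every hunk above makes the same kind of change, wrapping the bare message string of a PADDLE_ENFORCE_* check in platform::errors::InvalidArgument so the failed check reports a typed error instead of a raw string. A minimal before/after sketch of the call shape (the `rank` variable is illustrative only, not taken from the file):

    // Before: the message string is the final macro argument.
    // PADDLE_ENFORCE_LT(rank, 7, "The rank of input should be less than 7.");

    // After: the message is wrapped in an explicit error type.
    PADDLE_ENFORCE_LT(rank, 7,
                      platform::errors::InvalidArgument(
                          "The rank of input should be less than 7."));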