|
|
|
@ -27,24 +27,35 @@ class UnStackOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must exist.");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "UnStack");
|
|
|
|
|
int axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
|
int num = ctx->Attrs().Get<int>("num");
|
|
|
|
|
auto x_dim = ctx->GetInputDim("X");
|
|
|
|
|
int rank = x_dim.size();
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
axis, -rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
|
|
|
|
|
rank);
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
axis, rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
|
|
|
|
|
rank);
|
|
|
|
|
PADDLE_ENFORCE_GE(axis, -rank,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The attribute axis is out of range, it must be "
|
|
|
|
|
"inside [-rank, rank), where rank = %d",
|
|
|
|
|
rank));
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, rank,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The attribute axis is out of range, it must be "
|
|
|
|
|
"inside [-rank, rank), where rank = %d",
|
|
|
|
|
rank));
|
|
|
|
|
if (axis < 0) axis += rank;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
|
|
|
|
|
"Number of Outputs(Y) is wrong");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Number of Outputs(Y) is wrong. Got %d , but it must "
|
|
|
|
|
"equal to attribute num which is %d.",
|
|
|
|
|
ctx->Outputs("Y").size(), static_cast<size_t>(num)));
|
|
|
|
|
if (x_dim[axis] > 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
num, x_dim[axis],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of attribute num is not equal to the length of the "
|
|
|
|
|
"%d axis of Input(X). Expect %d but got %d.",
|
|
|
|
|
axis, x_dim[axis], num));
|
|
|
|
|
}
|
|
|
|
|
auto vec = framework::vectorize<int>(x_dim);
|
|
|
|
|
vec.erase(vec.begin() + axis);
|
|
|
|
@ -89,24 +100,29 @@ class UnStackGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
|
|
|
|
|
"Number of Inputs(Y@Grad) must be larger than 0");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
|
|
|
|
|
"Output(X@Grad) must exist.");
|
|
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Number of Inputs(Y@Grad) must be larger than 0"));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", "X",
|
|
|
|
|
"UnStackGrad");
|
|
|
|
|
auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
|
|
|
|
|
for (size_t i = 1; i < input_dims.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
|
|
|
|
|
"Dims of all Inputs(Y@Grad) must be the same");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Dims of all Inputs(Y@Grad) must be the same"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
|
int rank = input_dims[0].size();
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
axis, -(rank + 1),
|
|
|
|
|
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
axis, rank + 1,
|
|
|
|
|
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
|
|
|
|
|
PADDLE_ENFORCE_GE(axis, -(rank + 1),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The attribute axis is out of range, it must be "
|
|
|
|
|
"inside [-(rank+1), rank+1), where rank = %d",
|
|
|
|
|
rank));
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, rank + 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The attribute axis is out of range, it must be "
|
|
|
|
|
"inside [-(rank+1), rank+1), where rank = %d",
|
|
|
|
|
rank));
|
|
|
|
|
if (axis < 0) axis += (rank + 1);
|
|
|
|
|
|
|
|
|
|
auto vec = framework::vectorize<int>(input_dims[0]);
|
|
|
|
|