|
|
|
@ -27,8 +27,12 @@ inline std::vector<int> get_new_shape(
|
|
|
|
|
std::vector<int> vec_new_shape;
|
|
|
|
|
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
|
|
|
|
|
auto tensor = list_new_shape_tensor[i];
|
|
|
|
|
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
|
|
|
|
|
"shape of dim tensor should be [1]");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
tensor->dims(), framework::make_ddim({1}),
|
|
|
|
|
"ShapeError: If the element type of 'shape' in ReshapeOp is Tensor, "
|
|
|
|
|
"the element's shape must be [1]. But received the element's shape "
|
|
|
|
|
"is [%s]",
|
|
|
|
|
tensor->dims());
|
|
|
|
|
if (platform::is_gpu_place(tensor->place())) {
|
|
|
|
|
framework::Tensor temp;
|
|
|
|
|
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
|
|
|
|
@ -58,8 +62,12 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (ctx->HasInputs("ShapeTensor")) {
|
|
|
|
|
// top prority shape
|
|
|
|
|
auto ShapeTensor = ctx->Inputs("ShapeTensor");
|
|
|
|
|
PADDLE_ENFORCE_GT(ShapeTensor.size(), 0,
|
|
|
|
|
"The size of Input(ShapeTensor) can't be zero");
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
ShapeTensor.size(), 0,
|
|
|
|
|
"ShapeError: When `shape` in ReshapeOp is a list or tuple "
|
|
|
|
|
"which contains Tensor, the shape's size can't be zero. "
|
|
|
|
|
"But received shape's size is %d.",
|
|
|
|
|
ShapeTensor.size());
|
|
|
|
|
auto infer_shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
|
|
|
|
const int64_t copy_dim_val = 0;
|
|
|
|
|
auto in_dims = ctx->GetInputDim("X");
|
|
|
|
@ -67,8 +75,10 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (infer_shape[i] == copy_dim_val) {
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
static_cast<int>(i), in_dims.size(),
|
|
|
|
|
"The dimension of data to copy from input must be less "
|
|
|
|
|
"than the dimension of input.");
|
|
|
|
|
"ShapeError: The index of 0 in `shape` must be less than "
|
|
|
|
|
"the input tensor X's dimensions. But received shape[%d] "
|
|
|
|
|
"= 0, X's dimensions = %d, X's shape = [%s].",
|
|
|
|
|
i, in_dims.size(), in_dims);
|
|
|
|
|
infer_shape[i] = in_dims[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -98,8 +108,10 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(!shape.empty(), true,
|
|
|
|
|
"The shape information must be set by Attr(shape).");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
!shape.empty(), true,
|
|
|
|
|
"ShapeError: The parameter 'shape' in ReshapeOp must be set. "
|
|
|
|
|
"But received 'shape' is empty.");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto out_dims = ValidateShape(shape, x_dims);
|
|
|
|
|
ctx->SetOutputDim("Out", out_dims);
|
|
|
|
@ -128,18 +140,25 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (shape[i] == unk_dim_val) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
unk_dim_idx, -1,
|
|
|
|
|
"Only one input dimension of Attr(shape) can be unknown.");
|
|
|
|
|
"ShapeError: Only one dimension value of 'shape' in ReshapeOp can "
|
|
|
|
|
"be -1. But received shape = [%s], shape[%d] is also -1.",
|
|
|
|
|
framework::make_ddim(shape), i);
|
|
|
|
|
unk_dim_idx = i;
|
|
|
|
|
} else if (shape[i] == copy_dim_val) {
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
static_cast<int>(i), in_dims.size(),
|
|
|
|
|
"The index of dimension to copy from input shape must be less "
|
|
|
|
|
"than the size of input shape.");
|
|
|
|
|
"ShapeError: The index of 0 in `shape` must be less than "
|
|
|
|
|
"the input tensor X's dimensions. "
|
|
|
|
|
"But received shape = [%s], shape[%d] = 0, X's shape = [%s], "
|
|
|
|
|
"X's dimensions = %d.",
|
|
|
|
|
framework::make_ddim(shape), i, in_dims, in_dims.size());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
shape[i], 0,
|
|
|
|
|
"Each input dimension of Attr(shape) must not be negtive except "
|
|
|
|
|
"one unknown dimension.");
|
|
|
|
|
"ShapeError: Each dimension value of 'shape' in ReshapeOp must not "
|
|
|
|
|
"be negtive except one unknown dimension. "
|
|
|
|
|
"But received shape = [%s], shape[%d] = %d.",
|
|
|
|
|
framework::make_ddim(shape), i, shape[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
capacity *= (shape[i] ? shape[i] : in_dims[i]);
|
|
|
|
@ -155,12 +174,25 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
// the following check will fail.
|
|
|
|
|
output_shape[unk_dim_idx] = -in_size / capacity;
|
|
|
|
|
PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
|
|
|
|
|
"Invalid shape is given.");
|
|
|
|
|
"ShapeError: The 'shape' in ReshapeOp is invalid. "
|
|
|
|
|
"The input tensor X'size must be divisible by known "
|
|
|
|
|
"capacity of 'shape'. "
|
|
|
|
|
"But received X's shape = [%s], X's size = %d, "
|
|
|
|
|
"'shape' is [%s], known "
|
|
|
|
|
"capacity of 'shape' is %d.",
|
|
|
|
|
in_dims, in_size, framework::make_ddim(shape),
|
|
|
|
|
capacity);
|
|
|
|
|
} else {
|
|
|
|
|
output_shape[unk_dim_idx] = -1;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
capacity, in_size,
|
|
|
|
|
"ShapeError: The 'shape' in ReshapeOp is invalid. "
|
|
|
|
|
"The input tensor X'size must be equal to the capacity of 'shape'. "
|
|
|
|
|
"But received X's shape = [%s], X's size = %d, 'shape' is [%s], the "
|
|
|
|
|
"capacity of 'shape' is %d.",
|
|
|
|
|
in_dims, in_size, framework::make_ddim(shape), capacity);
|
|
|
|
|
}
|
|
|
|
|
return framework::make_ddim(output_shape);
|
|
|
|
|
}
|
|
|
|
@ -188,22 +220,25 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddInput("X", "(Tensor). The input tensor of reshape operator.");
|
|
|
|
|
AddInput("Shape",
|
|
|
|
|
"(Tensor<int32>, optional). If provided, reshape according to "
|
|
|
|
|
"this given shape. That is to say it has a higher priority than "
|
|
|
|
|
"the shape attribute, while the shape attribute still should be "
|
|
|
|
|
"(Tensor<int32>, optional). Target shape of reshape operator. "
|
|
|
|
|
"It has a higher priority than Attr(shape) but a lower priority "
|
|
|
|
|
"than Input(ShapeTensor). The Attr(shape) still should be "
|
|
|
|
|
"set correctly to gurantee shape inference in compile time.")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddInput(
|
|
|
|
|
"ShapeTensor",
|
|
|
|
|
"(vector<Tensor<int32>>, optional). If provided, reshape will use this"
|
|
|
|
|
"The shape of the tensor in vector MUST BE [1]"
|
|
|
|
|
"it has the highest priority compare with Input(Shape) and "
|
|
|
|
|
"attr(shape).")
|
|
|
|
|
"(vector<Tensor<int32>>, optional). Target shape of reshape operator. "
|
|
|
|
|
"It has the highest priority compare with Input(Shape) and "
|
|
|
|
|
"Attr(shape)."
|
|
|
|
|
"The shape of the element in vector must be [1].")
|
|
|
|
|
.AsDuplicable()
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddOutput("Out", "(Tensor). The output tensor of reshape operator.");
|
|
|
|
|
AddAttr<std::vector<int>>(
|
|
|
|
|
"shape", "(std::vector<int>) Target shape of reshape operator.")
|
|
|
|
|
"shape",
|
|
|
|
|
"(std::vector<int>) Target shape of reshape operator."
|
|
|
|
|
"It has the lowest priority compare with Input(Shape) and "
|
|
|
|
|
" Input(ShapeTensor).")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Reshape Operator.
|
|
|
|
|