|
|
|
@ -30,9 +30,9 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
|
|
|
|
|
const auto &x_dims = ctx->GetInputDim("X");
|
|
|
|
|
// Validity Check: input tensor dims (<6).
|
|
|
|
|
PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6,
|
|
|
|
|
"Invalid dimensions, dynamic dimensions should within "
|
|
|
|
|
"[1, 6] dimensions (Eigen limit).");
|
|
|
|
|
PADDLE_ENFORCE(x_dims.size() <= 6,
|
|
|
|
|
"Invalid dimensions, the rank of Input(X) "
|
|
|
|
|
"should be in the range of [1, 6] (Eigen limit)");
|
|
|
|
|
auto out_dims = GetOutputShape(axes, x_dims);
|
|
|
|
|
ctx->SetOutputDim("Out", out_dims);
|
|
|
|
|
if (x_dims[0] == out_dims[0]) {
|
|
|
|
@ -44,8 +44,8 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
|
|
|
|
|
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
|
|
|
|
|
const framework::DDim &in_dims) {
|
|
|
|
|
int output_size = static_cast<int>(in_dims.size() + unsqz_dims.size());
|
|
|
|
|
int cur_output_size = static_cast<int>(in_dims.size());
|
|
|
|
|
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
|
|
|
|
|
int cur_output_size = in_dims.size();
|
|
|
|
|
std::vector<int64_t> output_shape(output_size, 0);
|
|
|
|
|
|
|
|
|
|
// Validity Check: rank range.
|
|
|
|
@ -110,12 +110,11 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X", "(Tensor). The input tensor of unsqueeze operator.");
|
|
|
|
|
AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator.");
|
|
|
|
|
AddAttr<std::vector<int>>("axes",
|
|
|
|
|
"(std::vector<int>). List of positive integers,"
|
|
|
|
|
"(std::vector<int>). List of integers,"
|
|
|
|
|
" indicate the dimensions to be inserted")
|
|
|
|
|
.AddCustomChecker([](const std::vector<int> &axes) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
!axes.empty(),
|
|
|
|
|
"The unsqueeze axes information must be set by Attr(axes).");
|
|
|
|
|
PADDLE_ENFORCE(!axes.empty(),
|
|
|
|
|
"Invalid axes, The unsqueeze axes is empty.");
|
|
|
|
|
// Validity Check: axes dims (<6).
|
|
|
|
|
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
|
|
|
|
|
"Invalid dimensions, dynamic dimensions should within "
|
|
|
|
|