|
|
|
@ -30,13 +30,14 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
const auto &x_dims = ctx->GetInputDim("X");
|
|
|
|
|
// Check input tensor dims (<6) Eigen limit.
|
|
|
|
|
PADDLE_ENFORCE(x_dims.size() <= 6,
|
|
|
|
|
"Invalid dimnesions, dynamic dimensions must have "
|
|
|
|
|
"between [1, 6] dimensions (Eigen limit).");
|
|
|
|
|
"Invalid dimnesions, the rank of Input(X) "
|
|
|
|
|
"should be in the range of [1, 6] (Eigen limit).");
|
|
|
|
|
|
|
|
|
|
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
|
|
|
|
|
for (int a : axes) {
|
|
|
|
|
PADDLE_ENFORCE_LT(a, x_dims.size(),
|
|
|
|
|
"The axis must be less than input tensor's rank.");
|
|
|
|
|
"The squeeze axis should be less than input "
|
|
|
|
|
"tensor's rank.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_dims = GetOutputShape(axes, x_dims);
|
|
|
|
@ -50,30 +51,29 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
|
|
|
|
|
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
|
|
|
|
|
const framework::DDim &in_dims) {
|
|
|
|
|
int num_squeeze_dims = static_cast<int>(squeeze_dims.size());
|
|
|
|
|
size_t num_squeeze_dims = squeeze_dims.size();
|
|
|
|
|
int cnt_squeezed_dims = 0;
|
|
|
|
|
bool should_squeeze[9] = {false};
|
|
|
|
|
|
|
|
|
|
// Determines number of dimensions of output tensor after squeeze.
|
|
|
|
|
// Mark and count the dimensions need to be squeezed
|
|
|
|
|
if (num_squeeze_dims == 0) {
|
|
|
|
|
for (int idx = 0; idx < static_cast<int>(in_dims.size()); ++idx) {
|
|
|
|
|
for (int idx = 0; idx < in_dims.size(); ++idx) {
|
|
|
|
|
if (in_dims[idx] == 1) {
|
|
|
|
|
should_squeeze[idx] = true;
|
|
|
|
|
++cnt_squeezed_dims;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int idx = 0; idx < num_squeeze_dims; ++idx) {
|
|
|
|
|
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
|
|
|
|
|
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
|
|
|
|
|
: squeeze_dims[idx];
|
|
|
|
|
// Check current index.
|
|
|
|
|
// Check current index, the upper limit has beed checked in line 36.
|
|
|
|
|
PADDLE_ENFORCE(current >= 0,
|
|
|
|
|
"Invalid axis, negative axis is out of range.");
|
|
|
|
|
// PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
in_dims[current] == 1,
|
|
|
|
|
"Invalid axis index, the axis will be squeezed should be 1.");
|
|
|
|
|
"Invalid axis, the negative axis is out of range.");
|
|
|
|
|
PADDLE_ENFORCE(in_dims[current] == 1,
|
|
|
|
|
"Invalid axis index, the axis that will be squeezed "
|
|
|
|
|
"should equal 1.");
|
|
|
|
|
|
|
|
|
|
if (!(should_squeeze[current])) {
|
|
|
|
|
++cnt_squeezed_dims;
|
|
|
|
@ -84,8 +84,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
|
|
|
|
|
// Make output dimensions
|
|
|
|
|
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
|
|
|
|
|
for (int in_idx = 0, out_idx = 0; in_idx < static_cast<int>(in_dims.size());
|
|
|
|
|
++in_idx) {
|
|
|
|
|
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
|
|
|
|
|
if (!should_squeeze[in_idx]) {
|
|
|
|
|
output_shape[out_idx++] = in_dims[in_idx];
|
|
|
|
|
}
|
|
|
|
@ -123,7 +122,7 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X", "(Tensor). The input tensor of squeeze operator.");
|
|
|
|
|
AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
|
|
|
|
|
AddAttr<std::vector<int>>("axes",
|
|
|
|
|
"(std::vector<int>). List of positive integers,"
|
|
|
|
|
"(std::vector<int>). List of integers,"
|
|
|
|
|
" indicate the dimensions to squeeze.")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddAttr<bool>("inplace",
|
|
|
|
|