|
|
|
@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
"tensor's rank.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_dims = GetOutputShape(axes, x_dims);
|
|
|
|
|
auto out_dims = GetOutputShape(axes, x_dims, false);
|
|
|
|
|
ctx->SetOutputDim("Out", out_dims);
|
|
|
|
|
if (x_dims[0] == out_dims[0]) {
|
|
|
|
|
// Only pass LoD when the first dimension of output and Input(X)
|
|
|
|
@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
|
|
|
|
|
const framework::DDim &in_dims) {
|
|
|
|
|
const framework::DDim &in_dims,
|
|
|
|
|
bool is_runtime) {
|
|
|
|
|
size_t num_squeeze_dims = squeeze_dims.size();
|
|
|
|
|
int cnt_squeezed_dims = 0;
|
|
|
|
|
bool should_squeeze[9] = {false};
|
|
|
|
@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
// Check current index, the upper limit has beed checked in line 36.
|
|
|
|
|
PADDLE_ENFORCE(current >= 0,
|
|
|
|
|
"Invalid axis, the negative axis is out of range.");
|
|
|
|
|
PADDLE_ENFORCE(in_dims[current] == 1,
|
|
|
|
|
"Invalid axis index, the axis that will be squeezed "
|
|
|
|
|
"should be equal to 1.");
|
|
|
|
|
|
|
|
|
|
if (is_runtime) {
|
|
|
|
|
PADDLE_ENFORCE(in_dims[current] == 1,
|
|
|
|
|
"Invalid axis index, the axis that will be squeezed "
|
|
|
|
|
"should be equal to 1.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!(should_squeeze[current])) {
|
|
|
|
|
++cnt_squeezed_dims;
|
|
|
|
@ -104,7 +108,7 @@ class SqueezeOp : public framework::OperatorBase {
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto &axes = Attr<std::vector<int>>("axes");
|
|
|
|
|
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
|
|
|
|
|
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);
|
|
|
|
|
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims, true);
|
|
|
|
|
|
|
|
|
|
framework::AttributeMap attrs;
|
|
|
|
|
attrs["shape"] = framework::vectorize2int(out_dims);
|
|
|
|
@ -224,7 +228,7 @@ class Squeeze2Op : public framework::OperatorBase {
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto &axes = Attr<std::vector<int>>("axes");
|
|
|
|
|
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
|
|
|
|
|
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims);
|
|
|
|
|
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims, true);
|
|
|
|
|
|
|
|
|
|
framework::AttributeMap attrs;
|
|
|
|
|
attrs["shape"] = framework::vectorize2int(out_dims);
|
|
|
|
|