|
|
|
@ -13,15 +13,73 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/squeeze_op.h"
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
|
|
|
|
|
const framework::DDim &in_dims,
|
|
|
|
|
bool is_runtime) {
|
|
|
|
|
size_t num_squeeze_dims = squeeze_dims.size();
|
|
|
|
|
std::vector<bool> should_squeeze(in_dims.size(), false);
|
|
|
|
|
|
|
|
|
|
// Mark dimensions need to be squeezed.
|
|
|
|
|
if (num_squeeze_dims == 0) {
|
|
|
|
|
for (int i = 0; i < in_dims.size(); ++i) {
|
|
|
|
|
if (in_dims[i] == 1) {
|
|
|
|
|
should_squeeze[i] = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (size_t i = 0; i < num_squeeze_dims; ++i) {
|
|
|
|
|
int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims.size()
|
|
|
|
|
: squeeze_dims[i];
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
current, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Each axis in Attr(axes) should be in the range of [%d, %d]"
|
|
|
|
|
"But current axis is:%d, input tensor's shape = [%s].",
|
|
|
|
|
-in_dims.size(), in_dims.size() - 1, current, in_dims));
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
current, in_dims.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Each axis in Attr(axes) should be in the range of [%d, %d]"
|
|
|
|
|
"But current axis is:%d, input tensor's shape = [%s].",
|
|
|
|
|
-in_dims.size(), in_dims.size() - 1, current, in_dims));
|
|
|
|
|
|
|
|
|
|
if (!should_squeeze[current]) {
|
|
|
|
|
if (is_runtime) {
|
|
|
|
|
// At run time, dim of 1 is allowed to squeeze
|
|
|
|
|
if (in_dims[current] == 1) {
|
|
|
|
|
should_squeeze[current] = true;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// At compile time, dim of -1 or 1 is allowed to squeeze
|
|
|
|
|
if (in_dims[current] == 1 || in_dims[current] == -1) {
|
|
|
|
|
should_squeeze[current] = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Make output dimensions
|
|
|
|
|
std::vector<int64_t> output_shape;
|
|
|
|
|
for (int i = 0; i < in_dims.size(); ++i) {
|
|
|
|
|
if (!should_squeeze[i]) {
|
|
|
|
|
output_shape.push_back(in_dims[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return framework::make_ddim(output_shape);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class SqueezeOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -40,7 +98,7 @@ class SqueezeOp : public framework::OperatorWithKernel {
|
|
|
|
|
x_dims.size(), x_dims));
|
|
|
|
|
|
|
|
|
|
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
|
|
|
|
|
auto out_dims = GetOutputShape(axes, x_dims);
|
|
|
|
|
auto out_dims = GetOutputShape(axes, x_dims, false);
|
|
|
|
|
ctx->SetOutputDim("Out", out_dims);
|
|
|
|
|
if (x_dims[0] == out_dims[0]) {
|
|
|
|
|
// Only pass LoD when the first dimension of output and Input(X)
|
|
|
|
@ -49,56 +107,6 @@ class SqueezeOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
|
|
|
|
|
const framework::DDim &in_dims) {
|
|
|
|
|
size_t num_squeeze_dims = squeeze_dims.size();
|
|
|
|
|
int cnt_squeezed_dims = 0;
|
|
|
|
|
bool should_squeeze[9] = {false};
|
|
|
|
|
|
|
|
|
|
// Determines number of dimensions of output tensor after squeeze.
|
|
|
|
|
// Mark and count the dimensions need to be squeezed
|
|
|
|
|
if (num_squeeze_dims == 0) {
|
|
|
|
|
for (int idx = 0; idx < in_dims.size(); ++idx) {
|
|
|
|
|
if (in_dims[idx] == 1) {
|
|
|
|
|
should_squeeze[idx] = true;
|
|
|
|
|
++cnt_squeezed_dims;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
|
|
|
|
|
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
|
|
|
|
|
: squeeze_dims[idx];
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
current, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Each axis in Attr(axes) should be in the range of [%d, %d]"
|
|
|
|
|
"But current axis is:%d, input tensor's shape = [%s].",
|
|
|
|
|
-in_dims.size(), in_dims.size() - 1, current, in_dims));
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
current, in_dims.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Each axis in Attr(axes) should be in the range of [%d, %d]"
|
|
|
|
|
"But current axis is:%d, input tensor's shape = [%s].",
|
|
|
|
|
-in_dims.size(), in_dims.size() - 1, current, in_dims));
|
|
|
|
|
|
|
|
|
|
if (!(should_squeeze[current])) {
|
|
|
|
|
++cnt_squeezed_dims;
|
|
|
|
|
}
|
|
|
|
|
should_squeeze[current] = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Make output dimensions
|
|
|
|
|
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
|
|
|
|
|
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
|
|
|
|
|
if (!should_squeeze[in_idx]) {
|
|
|
|
|
output_shape[out_idx++] = in_dims[in_idx];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return framework::make_ddim(output_shape);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
@ -183,7 +191,7 @@ class Squeeze2Op : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
|
|
|
|
|
|
|
|
|
|
auto out_dims = SqueezeOp::GetOutputShape(axes, x_dims);
|
|
|
|
|
auto out_dims = GetOutputShape(axes, x_dims, false);
|
|
|
|
|
ctx->SetOutputDim("Out", out_dims);
|
|
|
|
|
if (x_dims[0] == out_dims[0]) {
|
|
|
|
|
// Only pass LoD when the first dimension of output and Input(X)
|
|
|
|
|