|
|
|
@ -23,9 +23,9 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of SqueezeOp should not be null.");
|
|
|
|
|
"Input(X) of Squeeze operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of SqueezeOp should not be null.");
|
|
|
|
|
"Output(Out) of Squeeze operator should not be null.");
|
|
|
|
|
|
|
|
|
|
const auto &x_dims = ctx->GetInputDim("X");
|
|
|
|
|
// Check input tensor dims (<6) Eigen limit.
|
|
|
|
@ -107,7 +107,6 @@ class SqueezeOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
framework::AttributeMap attrs;
|
|
|
|
|
attrs["shape"] = framework::vectorize2int(out_dims);
|
|
|
|
|
attrs["inplace"] = Attr<bool>("inplace");
|
|
|
|
|
// Invoke Reshape Op
|
|
|
|
|
auto reshape_op = framework::OpRegistry::CreateOp(
|
|
|
|
|
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
|
|
|
|
@ -125,12 +124,6 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"(std::vector<int>). List of integers,"
|
|
|
|
|
" indicating the dimensions to squeeze.")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddAttr<bool>("inplace",
|
|
|
|
|
"(default: false) Squeeze the source tensor's shape without "
|
|
|
|
|
"memory copy. When Attr(inplace) is set true, the output "
|
|
|
|
|
"tensor shares memory with Input(X), otherwise, a new output "
|
|
|
|
|
"tensor is created, and its data are copied from Input(x).")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Squeeze Operator.
|
|
|
|
|
|
|
|
|
@ -180,7 +173,6 @@ class SqueezeGradOp : public framework::OperatorBase {
|
|
|
|
|
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
|
|
|
|
|
framework::AttributeMap attrs;
|
|
|
|
|
attrs["shape"] = framework::vectorize2int(x_dims);
|
|
|
|
|
attrs["inplace"] = Attr<bool>("inplace");
|
|
|
|
|
|
|
|
|
|
auto reshape_op = framework::OpRegistry::CreateOp(
|
|
|
|
|
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
|
|
|
|
|