|
|
|
@ -127,13 +127,13 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
});
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Unsqueeze Operator.
|
|
|
|
|
|
|
|
|
|
Insert single-dimensional entries to the shape of a tensor.
|
|
|
|
|
Takes one required argument axes, a list of dimensions that will be inserted.
|
|
|
|
|
Dimension indices in axes are as seen in the output tensor.
|
|
|
|
|
|
|
|
|
|
For example:
|
|
|
|
|
Given a tensor such that tensor with shape [3, 4, 5],
|
|
|
|
|
Insert single-dimensional entries to the shape of a tensor.
|
|
|
|
|
Takes one required argument axes, a list of dimensions that will be inserted.
|
|
|
|
|
Dimension indices in axes are as seen in the output tensor.
|
|
|
|
|
|
|
|
|
|
For example:
|
|
|
|
|
Given a tensor such that tensor with shape [3, 4, 5],
|
|
|
|
|
then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
@ -168,6 +168,112 @@ class UnsqueezeGradOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on
|
|
|
|
|
// unsqueeze, the XShape is used to carry the shape and lod of X which
|
|
|
|
|
// will be used in unsqueeze_grad, in this way, the framework can reuse
|
|
|
|
|
// the memory of X immediately the unsqueeze2_op is finished.
|
|
|
|
|
// Considering compatibility issues, we could not fix unsqueeze2_op
|
|
|
|
|
class Unsqueeze2OpInferShape : public UnsqueezeOpInferShape {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
UnsqueezeOpInferShape::operator()(ctx);
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
|
|
|
|
|
"Output(XShape) of Unsqueeze operator should not be null.");
|
|
|
|
|
const auto &x_dims = ctx->GetInputDim("X");
|
|
|
|
|
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
|
|
|
|
|
xshape_dims[0] = 0;
|
|
|
|
|
for (int i = 0; i < x_dims.size(); ++i) {
|
|
|
|
|
xshape_dims[i + 1] = x_dims[i];
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "XShape");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
|
UnsqueezeOpMaker::Make();
|
|
|
|
|
AddOutput("XShape",
|
|
|
|
|
"XShape is just used to store the shape and lod of X, which will "
|
|
|
|
|
"be used in UnsqueezeGradOp.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Unsqueeze2Op : public framework::OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
using OperatorBase::OperatorBase;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto &axes = Attr<std::vector<int>>("axes");
|
|
|
|
|
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
|
|
|
|
|
auto out_dims = Unsqueeze2OpInferShape::GetOutputShape(axes, x_dims);
|
|
|
|
|
|
|
|
|
|
framework::AttributeMap attrs;
|
|
|
|
|
attrs["shape"] = framework::vectorize2int(out_dims);
|
|
|
|
|
// Invoke Reshape op.
|
|
|
|
|
auto reshape_op = framework::OpRegistry::CreateOp(
|
|
|
|
|
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
|
|
|
|
|
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
|
|
|
|
|
reshape_op->Run(scope, place);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Unsqueeze2GradOpMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
auto *grad_op = new framework::OpDesc();
|
|
|
|
|
grad_op->SetType("unsqueeze2_grad");
|
|
|
|
|
grad_op->SetInput("XShape", Output("XShape"));
|
|
|
|
|
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
grad_op->SetAttrMap(Attrs());
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(grad_op);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Unsqueeze2GradInferShape : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput("XShape"),
|
|
|
|
|
"Input(XShape) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
auto xshape_dims = context->GetInputDim("XShape");
|
|
|
|
|
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
|
|
|
|
|
context->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
context->ShareLoD("XShape", framework::GradVarName("X"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Unsqueeze2GradOp : public framework::OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
using OperatorBase::OperatorBase;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto dx_name = Output(framework::GradVarName("X"));
|
|
|
|
|
auto dout_name = Input(framework::GradVarName("Out"));
|
|
|
|
|
auto xshape_name = Input("XShape");
|
|
|
|
|
auto xshape_dims =
|
|
|
|
|
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
|
|
|
|
|
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
|
|
|
|
|
|
|
|
|
|
framework::AttributeMap attrs;
|
|
|
|
|
attrs["shape"] = framework::vectorize2int(x_dims);
|
|
|
|
|
|
|
|
|
|
auto reshape_op = framework::OpRegistry::CreateOp(
|
|
|
|
|
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
|
|
|
|
|
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
|
|
|
|
|
reshape_op->Run(scope, place);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -180,3 +286,8 @@ REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
|
|
|
|
|
ops::UnsqueezeGradInferShape);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
|
|
|
|
|
ops::Unsqueeze2OpInferShape, ops::Unsqueeze2GradOpMaker);
|
|
|
|
|
REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
|
|
|
|
|
ops::Unsqueeze2GradInferShape);
|
|
|
|
|