|
|
|
@ -24,26 +24,28 @@ class ExpandOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
|
|
|
|
|
std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), expand_times.size(),
|
|
|
|
|
"The number of expandTimes's value must be equal "
|
|
|
|
|
"to the rank of X.");
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
|
|
|
|
|
std::vector<int> expand_times =
|
|
|
|
|
ctx->Attrs().Get<std::vector<int>>("expandTimes");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
|
|
|
|
|
"The number of Attr(expandTimes)'s value must be equal "
|
|
|
|
|
"to the rank of Input(X).");
|
|
|
|
|
PADDLE_ENFORCE_LE(x_dims.size(), 6,
|
|
|
|
|
"The rank of X must not be greater than 6.");
|
|
|
|
|
"The rank of Input(X) must not be greater than 6.");
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> out_shape(x_dims.size());
|
|
|
|
|
for (size_t i = 0; i < expand_times.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_GE(expand_times[i], 1,
|
|
|
|
|
"Each value of expandTimes should not be "
|
|
|
|
|
"Each value of Attr(expandTimes) should not be "
|
|
|
|
|
"less than 1.");
|
|
|
|
|
out_shape[i] = x_dims[i] * expand_times[i];
|
|
|
|
|
}
|
|
|
|
|
auto* out = ctx.Output<framework::LoDTensor>("Out");
|
|
|
|
|
out->Resize(framework::make_ddim(out_shape));
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
|
|
|
|
|
ctx->ShareLoD("X", "Out");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -52,20 +54,21 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
ExpandOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"The input tensor of expand op."
|
|
|
|
|
"The rank of X should be between in 1 and 6.");
|
|
|
|
|
"(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
|
|
|
|
|
"X is the input tensor to be expanded.");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"Output tensor of expand op."
|
|
|
|
|
"The rank of Out is same as X except that each dimension size "
|
|
|
|
|
"of Out equals to corresponding dimension size of X multiplying "
|
|
|
|
|
"corresponding value of expandTimes.");
|
|
|
|
|
"(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
|
|
|
|
|
"The rank of Output(Out) is same as Input(X) except that each "
|
|
|
|
|
"dimension size of Output(Out) is equal to corresponding "
|
|
|
|
|
"dimension size of Input(X) multiplying corresponding value of "
|
|
|
|
|
"Attr(expandTimes).");
|
|
|
|
|
AddAttr<std::vector<int>>("expandTimes",
|
|
|
|
|
"Expand times number for each dimension.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Expand operator tiles the input by given times number. You should set times
|
|
|
|
|
number for each dimension by providing attribute 'expandTimes'. The rank of X
|
|
|
|
|
should be between in 1 and 6. Please notice that size of 'expandTimes' must be
|
|
|
|
|
same with X's rank.
|
|
|
|
|
should be in [1, 6]. Please notice that size of 'expandTimes' must be same with
|
|
|
|
|
X's rank.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -75,25 +78,27 @@ class ExpandGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null.");
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
|
|
|
|
|
auto out_dims =
|
|
|
|
|
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->dims();
|
|
|
|
|
auto* x_grad =
|
|
|
|
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null.");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
std::vector<int> expand_times =
|
|
|
|
|
ctx->Attrs().Get<std::vector<int>>("expandTimes");
|
|
|
|
|
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < expand_times.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
|
|
|
|
|
"Each dimension size of Input(Out@GRAD) should be "
|
|
|
|
|
"equal to multiplication of crroresponding dimension "
|
|
|
|
|
"size of Input(X) and expandTimes value.");
|
|
|
|
|
"size of Input(X) and Attr(expandTimes) value.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (x_grad) x_grad->Resize(x_dims);
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
|
ctx->SetOutputDim(x_grad_name, x_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|