|
|
|
@ -25,10 +25,8 @@ class SeqExpandOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of SeqExpandOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of SeqExpandOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"));
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("Y"),
|
|
|
|
|
"Input(Y) of SeqExpandOp should not be null while repeat == 0.");
|
|
|
|
@ -54,7 +52,7 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"The element numbers of last level in input('Y') "
|
|
|
|
|
"must be equal to dims[0] of input('X').");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"The output of seq_expand op."
|
|
|
|
|
"(LodTensor)The output of seq_expand op."
|
|
|
|
|
"The lod of output will be as same as input(Y)'s lod.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Expand input(X) according to LOD of input(Y).
|
|
|
|
@ -69,6 +67,7 @@ Given 2-level a LoDTensor input(X)
|
|
|
|
|
and input(Y)
|
|
|
|
|
Y.lod = [[0, 2, 4],
|
|
|
|
|
[0, 3, 6, 7, 8]]
|
|
|
|
|
with condition len(Y.lod[-1]) -1 == X.dims[0]
|
|
|
|
|
then we get 2-level LoDTensor
|
|
|
|
|
Out.lod = [[0, 2, 4],
|
|
|
|
|
[0, 3, 6, 7, 8]]
|
|
|
|
@ -83,6 +82,7 @@ Given a 0-level LoDTensor input(X)
|
|
|
|
|
X.dims = [3, 1]
|
|
|
|
|
and input(Y)
|
|
|
|
|
Y.lod = [[0, 2, 3, 6]]
|
|
|
|
|
with condition len(Y.lod[-1]) -1 == X.dims[0]
|
|
|
|
|
then we get 1-level LoDTensor
|
|
|
|
|
Out.lod = [[0, 2, 3, 6]]
|
|
|
|
|
Out.data = [a, a, b, c, c, c]
|
|
|
|
@ -96,11 +96,29 @@ Given a 0-level LoDTensor input(X)
|
|
|
|
|
X.dims = [3, 2]
|
|
|
|
|
and input(Y)
|
|
|
|
|
Y.lod = [[0, 2, 3, 6]]
|
|
|
|
|
with condition len(Y.lod[-1]) -1 == X.dims[0]
|
|
|
|
|
then we get 1-level LoDTensor
|
|
|
|
|
Out.lod = [[0, 2, 3, 6]]
|
|
|
|
|
Out.data = [[a,b], [a,b] [c,d], [e, f], [e, f], [e, f]]
|
|
|
|
|
Out.dims = [6, 2]
|
|
|
|
|
|
|
|
|
|
Case 4:
|
|
|
|
|
|
|
|
|
|
Given 2-level a LoDTensor input(X)
|
|
|
|
|
X.lod = [[0, 2, 3],
|
|
|
|
|
[0, 1, 3, 4]]
|
|
|
|
|
X.data = [a, b, c, d]
|
|
|
|
|
X.dims = [4, 1]
|
|
|
|
|
and input(Y)
|
|
|
|
|
Y.lod = [[0, 2, 4],
|
|
|
|
|
[0, 3, 6, 6, 8]]
|
|
|
|
|
with condition len(Y.lod[-1]) -1 == X.dims[0]
|
|
|
|
|
then we get 2-level LoDTensor
|
|
|
|
|
Out.lod = [[0, 2, 4],
|
|
|
|
|
[0, 3, 6, 6, 8]]
|
|
|
|
|
Out.data = [a, a, a, b, b, b, d, d]
|
|
|
|
|
Out.dims = [8, 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
@ -112,8 +130,8 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"));
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Out"));
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|