@@ -48,11 +48,11 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
                         framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "The input Multip LoDTensors, which are variable-length "
-             "sequence or nested sequence.")
+             "(A vector of LoDTensor), the input is a vector of LoDTensor, "
+             "each of which is a variable-length sequence or nested sequence.")
         .AsDuplicable();
     AddOutput("Out",
-              "A LoDTensor, the variable-length output of "
+              "(A LoDTensor), the variable-length output of "
               "sequence_concat Op.");
     AddAttr<int>("axis",
                  "(int, default 0)"
@@ -61,27 +61,36 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(0);
     AddAttr<int>("level",
                  "(int, default 0)"
-                 "The level which the inputs will be joined with."
-                 "If level is 0, the inputs will be joined with "
-                 "nested sequences."
-                 "If level is 1, the inputs will be joined with sequences.")
+                 "The level at which the inputs will be joined. "
+                 "If the level is 0, the inputs will be joined at the nested "
+                 "sequence level. "
+                 "If the level is 1, the inputs will be joined at the "
+                 "sequence level.")
         .SetDefault(0);
     AddComment(R"DOC(
 The sequence_concat operator concatenates multiple LoDTensors.
-It only supports sequences ( LoD Tensor with level=1)
-or nested sequences (LoD tensor with level=0) as its inputs.
+It only supports a sequence (LoD Tensor whose level number is 1)
+or a nested sequence (LoD Tensor whose level number is 2) as its input.
 - Case1:
-  If the axis is 1, level is 1, the LoD of Inputs are the same,
-  LoD(x0) = {{0,2,4},{0,1,2,3,4}}; Dims(x0) = (2,3,4)
-  LoD(x1) = {{0,2,4},{0,1,2,3,4}}; Dims(x1) = (2,4,4)
-  LoD(Out) = {{0,2,4},{0,1,2,3,4}}; Dims(Out) = (2,7,4)
+  If the axis is other than 0 (here, axis is 1 and level is 1),
+  each input should have the same LoD information, and the LoD
+  information of the output is the same as that of the inputs.
+  LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
+  LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (4,4,4)
+  LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4)
 - Case2:
-  If the axis is 0, level is 1, the LoD of inputs are different,
-  LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (2,3,4)
-  LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (3,3,4)
-  LoD(Out) = {{0,5,9}, {0,1,2,4,5,6,7,8,9}}; Dims(Out) = (5,3,4)
-
-NOTE: The level of all the inputs should be the same.
+  If the axis is 0 (here, the level is 0), the inputs are concatenated along
+  the time steps, and the LoD information of the output needs to be re-computed.
+  LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
+  LoD(x1) = {{0,3,5}, {0,1,2,3,5}}; Dims(x1) = (5,3,4)
+  LoD(Out) = {{0,5,9}, {0,1,2,3,4,5,6,7,9}}; Dims(Out) = (9,3,4)
+- Case3:
+  If the axis is 0 (here, the level is 1).
+  LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
+  LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4)
+  LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4)
+
+NOTE: The levels of all the inputs should be the same.
     )DOC");
   }
 };
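
The LoD re-computation in Case2 can be sanity-checked with a small standalone sketch. This is not the operator's kernel code; it only reproduces the LoD arithmetic of the example, it assumes (as the numbers above suggest) that every LoD level stores absolute row offsets, and the names `Lod`, `SubLengthsInTopSeq`, and `ConcatLodAxis0Level0` are illustrative, not Paddle identifiers.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Each LoD level is taken as absolute row offsets into the tensor.
using Lod = std::vector<std::vector<size_t>>;

// Lengths of the sub-sequences that fall inside the i-th top-level sequence.
std::vector<size_t> SubLengthsInTopSeq(const Lod& lod, size_t i) {
  std::vector<size_t> lengths;
  for (size_t j = 0; j + 1 < lod[1].size(); ++j) {
    if (lod[1][j] >= lod[0][i] && lod[1][j + 1] <= lod[0][i + 1]) {
      lengths.push_back(lod[1][j + 1] - lod[1][j]);
    }
  }
  return lengths;
}

// axis = 0, level = 0: the i-th top-level sequence of the output is the i-th
// top-level sequence of x0 followed by the i-th one of x1, so both LoD levels
// are rebuilt from the merged sub-sequence lengths.
Lod ConcatLodAxis0Level0(const Lod& a, const Lod& b) {
  Lod out(2, std::vector<size_t>{0});
  for (size_t i = 0; i + 1 < a[0].size(); ++i) {
    for (const Lod* src : {&a, &b}) {
      for (size_t len : SubLengthsInTopSeq(*src, i)) {
        out[1].push_back(out[1].back() + len);
      }
    }
    out[0].push_back(out[0].back() + (a[0][i + 1] - a[0][i]) +
                     (b[0][i + 1] - b[0][i]));
  }
  return out;
}

int main() {
  Lod x0 = {{0, 2, 4}, {0, 1, 2, 3, 4}};  // Dims(x0) = (4,3,4)
  Lod x1 = {{0, 3, 5}, {0, 1, 2, 3, 5}};  // Dims(x1) = (5,3,4)
  for (const auto& level : ConcatLodAxis0Level0(x0, x1)) {
    for (size_t v : level) std::cout << v << ' ';
    std::cout << '\n';  // prints: 0 5 9   and   0 1 2 3 4 5 6 7 9
  }
  return 0;
}
```

Built with any C++11 compiler, it prints `0 5 9` and `0 1 2 3 4 5 6 7 9`, i.e. LoD(Out) from Case2; Dims(Out) then follows as (4+5, 3, 4) = (9,3,4).
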
@@ -95,7 +104,7 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "The gradient of Out should not be null.");
     PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
-                   "The gradient of X should not be empty.");
+                   "The gradient of X should not be null.");
     ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
   }
 };