|
|
|
@ -12,14 +12,14 @@
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/seq_expand_op.h"
|
|
|
|
|
#include "paddle/operators/sequence_expand_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using framework::Tensor;
|
|
|
|
|
|
|
|
|
|
class SeqExpandOp : public framework::OperatorWithKernel {
|
|
|
|
|
class SequenceExpandOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
@ -35,25 +35,25 @@ class SeqExpandOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
SeqExpandOpMaker(framework::OpProto* proto,
|
|
|
|
|
framework::OpAttrChecker* op_checker)
|
|
|
|
|
SequenceExpandOpMaker(framework::OpProto* proto,
|
|
|
|
|
framework::OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"(Tensor or LoDTensor) The input(X) of this operator can be a "
|
|
|
|
|
"LoDTensor or a base Tensor.");
|
|
|
|
|
AddInput("Y",
|
|
|
|
|
"(LoDTensor)The reference input(Y) of seq_expand op."
|
|
|
|
|
"(LoDTensor)The reference input(Y) of sequence_expand op."
|
|
|
|
|
"It must be a LoDTensor with k-level(k>0)."
|
|
|
|
|
"The input(X) will be expanded according to LOD of input(Y)."
|
|
|
|
|
"The element numbers of last level in input(Y) "
|
|
|
|
|
"must be equal to dims[0] of input(X).");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(LodTensor)The output of seq_expand op."
|
|
|
|
|
"(LodTensor)The output of sequence_expand op."
|
|
|
|
|
"The lod of output will be as same as input(Y)'s lod.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Seq Expand Operator.
|
|
|
|
|
Sequence Expand Operator.
|
|
|
|
|
|
|
|
|
|
This operator expands input(X) according to LOD of input(Y).
|
|
|
|
|
Following are cases to better explain how this works:
|
|
|
|
@ -124,7 +124,7 @@ then we get 2-level LoDTensor
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SeqExpandOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
class SequenceExpandOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
@ -146,11 +146,11 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
|
|
|
|
|
seq_expand_grad, ops::SeqExpandOpGrad);
|
|
|
|
|
REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker,
|
|
|
|
|
sequence_expand_grad, ops::SequenceExpandOpGrad);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
seq_expand,
|
|
|
|
|
ops::SeqExpandKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
sequence_expand,
|
|
|
|
|
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
seq_expand_grad,
|
|
|
|
|
ops::SeqExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
sequence_expand_grad,
|
|
|
|
|
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
|