|
|
|
@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/ctc_decode_op.h"
|
|
|
|
|
#include "paddle/operators/ctc_align_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
class CTCDecodeOp : public framework::OperatorWithKernel {
|
|
|
|
|
class CTCAlignOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input of CTCDecodeOp should not be null.");
|
|
|
|
|
"Input of CTCAlignOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Output"),
|
|
|
|
|
"Output of CTCDecodeOp should not be null.");
|
|
|
|
|
"Output of CTCAlignOp should not be null.");
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("Input");
|
|
|
|
|
|
|
|
|
@ -42,14 +42,14 @@ class CTCDecodeOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|
CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("Input",
|
|
|
|
|
"(LodTensor, default: LoDTensor<int>), Its shape is "
|
|
|
|
|
"[Lp, 1], where Lp is the sum of all input sequences' length.");
|
|
|
|
|
AddOutput("Output", "(Tensor, default: Tensor<int>), The decode result.");
|
|
|
|
|
AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
|
|
|
|
|
AddAttr<int>("blank",
|
|
|
|
|
"(int, default: 0), the blank label setted in Connectionist "
|
|
|
|
|
"Temporal Classification (CTC) op.")
|
|
|
|
@ -59,7 +59,7 @@ class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"merge repeated elements between two blanks. ")
|
|
|
|
|
.SetDefault(true);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
CTCDecoder is used to merge repeated elements between two blanks
|
|
|
|
|
CTCAlign op is used to merge repeated elements between two blanks
|
|
|
|
|
and then delete all blanks in sequence.
|
|
|
|
|
|
|
|
|
|
Given:
|
|
|
|
@ -86,7 +86,7 @@ Then:
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker,
|
|
|
|
|
REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
ctc_decode, ops::CTCDecodeKernel<paddle::platform::CPUDeviceContext, int>);
|
|
|
|
|
ctc_align, ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int>);
|