|
|
@ -12,12 +12,12 @@
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License. */
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/ctc_edit_distance_op.h"
|
|
|
|
#include "paddle/operators/edit_distance_op.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
class CTCEditDistanceOp : public framework::OperatorWithKernel {
|
|
|
|
class EditDistanceOp : public framework::OperatorWithKernel {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
@ -29,17 +29,16 @@ class CTCEditDistanceOp : public framework::OperatorWithKernel {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
framework::OpKernelType GetKernelType(
|
|
|
|
framework::OpKernelType GetActualKernelType(
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
return framework::OpKernelType(framework::DataType::FP32,
|
|
|
|
return framework::OpKernelType(framework::proto::DataType::FP32,
|
|
|
|
ctx.device_context());
|
|
|
|
ctx.device_context());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
CTCEditDistanceOpMaker(framework::OpProto *proto,
|
|
|
|
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
AddInput("X1",
|
|
|
|
AddInput("X1",
|
|
|
|
"(2-D tensor with shape [M x 1]) The indices for "
|
|
|
|
"(2-D tensor with shape [M x 1]) The indices for "
|
|
|
@ -54,10 +53,10 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
.SetDefault(false);
|
|
|
|
.SetDefault(false);
|
|
|
|
AddOutput("Out",
|
|
|
|
AddOutput("Out",
|
|
|
|
"(2-D tensor with shape [1 x 1]) "
|
|
|
|
"(2-D tensor with shape [1 x 1]) "
|
|
|
|
"The output distance of CTCEditDistance operator.");
|
|
|
|
"The output distance of EditDistance operator.");
|
|
|
|
AddComment(R"DOC(
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
|
|
|
|
|
|
|
CTCEditDistance operator computes the edit distance of two sequences, one named
|
|
|
|
EditDistance operator computes the edit distance of two sequences, one named
|
|
|
|
hypothesis with length M and another named reference with length N.
|
|
|
|
hypothesis with length M and another named reference with length N.
|
|
|
|
|
|
|
|
|
|
|
|
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
|
|
|
|
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
|
|
|
@ -80,8 +79,7 @@ reference string N.
|
|
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp,
|
|
|
|
REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker,
|
|
|
|
ops::CTCEditDistanceOpMaker);
|
|
|
|
paddle::framework::EmptyGradOpMaker);
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
ctc_edit_distance,
|
|
|
|
edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
ops::CTCEditDistanceKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|