|
|
|
@ -47,20 +47,19 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto dist = dist_t.data<T>();
|
|
|
|
|
auto x1 = x1_t->data<T>();
|
|
|
|
|
auto x2 = x2_t->data<T>();
|
|
|
|
|
for (int i = 0; i < m + 1; ++i) {
|
|
|
|
|
dist[i * (n + 1)] = i; // dist[i][0] = i;
|
|
|
|
|
for (size_t i = 0; i < m + 1; ++i) {
|
|
|
|
|
dist[i * (n + 1)] = i;
|
|
|
|
|
}
|
|
|
|
|
for (int j = 0; j < n + 1; ++j) {
|
|
|
|
|
dist[j] = j; // dist[0][j] = j;
|
|
|
|
|
for (size_t j = 0; j < n + 1; ++j) {
|
|
|
|
|
dist[j] = j;
|
|
|
|
|
}
|
|
|
|
|
for (int i = 1; i < m + 1; ++i) {
|
|
|
|
|
for (int j = 1; j < n + 1; ++j) {
|
|
|
|
|
for (size_t i = 1; i < m + 1; ++i) {
|
|
|
|
|
for (size_t j = 1; j < n + 1; ++j) {
|
|
|
|
|
int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
|
|
|
|
|
int deletions = dist[(i - 1) * (n + 1) + j] + 1;
|
|
|
|
|
int insertions = dist[i * (n + 1) + (j - 1)] + 1;
|
|
|
|
|
int substitutions = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
|
|
|
|
|
dist[i * (n + 1) + j] =
|
|
|
|
|
std::min(deletions, std::min(insertions, substitutions));
|
|
|
|
|
int dels = dist[(i - 1) * (n + 1) + j] + 1;
|
|
|
|
|
int ins = dist[i * (n + 1) + (j - 1)] + 1;
|
|
|
|
|
int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
|
|
|
|
|
dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
distance = dist[m * (n + 1) + n];
|
|
|
|
|