Merge pull request #5300 from kuke/ctc_edit_distance_dev
Add edit distance operator
commit 861b84f557
@@ -0,0 +1,98 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/edit_distance_op.h"

namespace paddle {
namespace operators {

class EditDistanceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
    PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
    auto hyp_dims = ctx->GetInputDim("Hyps");
    auto ref_dims = ctx->GetInputDim("Refs");
    PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
                   "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
                   "equal to 1.");
    PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1,
                   "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
                   "equal to 1.");
    ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(framework::proto::DataType::FP32,
                                   ctx.device_context());
  }
};

class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Hyps",
             "(2-D LoDTensor<int>, 2nd dim. equal to 1) "
             "The indices for hypothesis strings.");
    AddInput("Refs",
             "(2-D LoDTensor<int>, 2nd dim. equal to 1) "
             "The indices for reference strings.");
    AddAttr<bool>("normalized",
                  "(bool, default false) Indicates whether to normalize "
                  "the edit distance by the length of the reference string.")
        .SetDefault(false);
    AddOutput("Out",
              "(2-D Tensor with shape [`batch_size` x 1]) "
              "The output edit distances of EditDistance operator.");
    AddComment(R"DOC(

EditDistance operator computes the edit distances between a batch of hypothesis
strings and their references.

Edit distance, also called Levenshtein distance, measures how dissimilar two strings
are by counting the minimum number of operations to transform one string into another.
Here the operations include insertion, deletion, and substitution. For example,
given hypothesis string A = "kitten" and reference B = "sitting", the edit distance
is 3, since transforming A into B requires at least two substitutions and one
insertion:

   "kitten" -> "sitten" -> "sittin" -> "sitting"

Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total
number denoted by `batch_size`, and the separation is specified by the LoD information.
The `batch_size` reference strings are arranged in order in the same way in the
LoDTensor Input(Refs).

Output(Out) contains the `batch_size` results and each stands for the edit distance
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
will be divided by the length of the reference string.

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
    edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>);
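For reference, the recurrence that both the CPU and GPU kernels below implement is the classic dynamic program dist[i][0] = i, dist[0][j] = j, dist[i][j] = min(deletion, insertion, substitution). A minimal NumPy sketch (an illustration only, not part of this PR; edit_distance is a hypothetical name) that reproduces the "kitten" -> "sitting" example from the doc string above:

import numpy as np

def edit_distance(hyp, ref, normalized=False):
    m, n = len(hyp), len(ref)
    dist = np.zeros((m + 1, n + 1), dtype="float32")
    dist[:, 0] = np.arange(m + 1)  # deleting all of hyp[:i]
    dist[0, :] = np.arange(n + 1)  # inserting all of ref[:j]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            dist[i, j] = min(dist[i - 1, j] + 1,         # deletion
                             dist[i, j - 1] + 1,         # insertion
                             dist[i - 1, j - 1] + cost)  # substitution
    d = dist[m, n]
    return d / n if normalized else d

print(edit_distance("kitten", "sitting"))                   # -> 3.0
print(edit_distance("kitten", "sitting", normalized=True))  # -> 3.0 / 7

With normalized=True the same example yields 3/7, matching the Attr(normalized) semantics described above.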
@@ -0,0 +1,149 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/framework/op_registry.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;

template <typename T>
__global__ void FillFirstRow(T* dist, const int N) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < N + 1) {
    dist[idx] = idx;
  }
}

template <typename T>
__global__ void FillFirstColumn(T* dist, const int M, const int N) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < M + 1) {
    dist[idx * (N + 1)] = idx;
  }
}

template <typename T>
__global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M,
                            const int N, const int start) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  int offset = N;
  int index = start + idx * offset;
  int row = index / (N + 1);
  int col = index % (N + 1);
  if (row > 0 && col > 0 && row < M + 1 && col < N + 1) {
    int cost = x1[row - 1] == x2[col - 1] ? 0 : 1;
    int dels = dist[(row - 1) * (N + 1) + col] + 1;
    int ins = dist[row * (N + 1) + col - 1] + 1;
    int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost;
    dist[index] = min(dels, min(ins, subs));
  }
}

template <typename T>
__global__ void SetOutput(T* out, const T* dist, const int M, const int N,
                          bool normalized) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx == 0) {
    out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N];
  }
}

template <typename Place, typename T>
class EditDistanceGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto* out_t = ctx.Output<framework::Tensor>("Out");

    auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
    auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");

    auto normalized = ctx.Attr<bool>("normalized");
    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();

    auto hyp_lod = x1_t->lod()[0];
    auto ref_lod = x2_t->lod()[0];
    PADDLE_ENFORCE(
        hyp_lod.size() == ref_lod.size(),
        "Input(Hyps) and Input(Refs) must have the same batch size.");
    for (size_t i = 1; i < ref_lod.size(); ++i) {
      PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
                     "Reference string %d is empty.", i);
    }

    auto num_strs = hyp_lod.size() - 1;
    out_t->Resize({static_cast<int64_t>(num_strs), 1});
    out_t->mutable_data<T>(ctx.GetPlace());
    auto out = out_t->data<T>();

    T distance = 0.0;
    for (size_t num = 0; num < num_strs; num++) {
      auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
      auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
      if (m == 0 || n == 0) {
        distance = std::max(m, n);
        if (normalized) {
          PADDLE_ENFORCE(n > 0,
                         "The reference string (#%d) cannot be empty "
                         "when Attr(normalized) is enabled.",
                         n);
          distance = distance / n;
        }
        memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num,
                     platform::CPUPlace(), &distance, sizeof(T), stream);
      } else {
        framework::Tensor dist_t;
        dist_t.Resize({m + 1, n + 1});
        dist_t.mutable_data<T>(ctx.GetPlace());
        auto dist = dist_t.data<T>();
        auto x1 = x1_t->data<int>() + hyp_lod[num];
        auto x2 = x2_t->data<int>() + ref_lod[num];

        FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
                             PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);

        FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS,
                          PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n);
        // Compute the elements of the distance matrix in the anti-diagonal
        // direction, since cells on the same anti-diagonal are independent.
        for (int64_t slice = 2; slice < m + n + 1; ++slice) {
          int z_m = slice < m + 1 ? 0 : slice - m;
          int z_n = slice < n + 1 ? 0 : slice - n;
          int size = slice - (z_m + z_n) + 1;  // number of elements on the
                                               // same anti-diagonal line to
                                               // update
          // the start index from which the computation begins
          int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1;
          Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS,
                           PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2,
                                                                 m, n, start);
        }
        SetOutput<T><<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    edit_distance,
    ops::EditDistanceGPUKernel<paddle::platform::CUDAPlace, float>);
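The trickiest part of the GPU kernel is the anti-diagonal (wavefront) schedule: cells with the same row + col are independent of each other, so each launch of Levenshtein updates one anti-diagonal, stepping through linear indices start + idx * N, i.e. one row down and one column left per thread. A small host-side Python check (hypothetical, not part of this PR; antidiagonal_cells is an illustrative name) that mirrors the kernel's z_m / z_n / size / start arithmetic and verifies every interior cell is covered exactly once:

def antidiagonal_cells(m, n, slice_):
    z_m = 0 if slice_ < m + 1 else slice_ - m
    z_n = 0 if slice_ < n + 1 else slice_ - n
    size = slice_ - (z_m + z_n) + 1  # threads launched for this slice
    start = slice_ if slice_ < n + 1 else (z_n + 1) * (n + 1) - 1
    cells = []
    for idx in range(size):
        index = start + idx * n                    # the kernel's offset is N
        row, col = divmod(index, n + 1)
        if 0 < row < m + 1 and 0 < col < n + 1:    # same guard as the kernel
            cells.append((row, col))
    return cells

m, n = 4, 6
for s in range(2, m + n + 1):
    expected = [(r, s - r) for r in range(1, m + 1) if 1 <= s - r <= n]
    assert antidiagonal_cells(m, n, s) == expected
print("each interior cell is updated on exactly one slice")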
@@ -0,0 +1,96 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class EditDistanceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto* out_t = ctx.Output<framework::Tensor>("Out");

    auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
    auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");

    auto normalized = ctx.Attr<bool>("normalized");

    auto hyp_lod = x1_t->lod()[0];
    auto ref_lod = x2_t->lod()[0];
    PADDLE_ENFORCE(
        hyp_lod.size() == ref_lod.size(),
        "Input(Hyps) and Input(Refs) must have the same batch size.");
    for (size_t i = 1; i < ref_lod.size(); ++i) {
      PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
                     "Reference string %d is empty.", i);
    }
    auto num_strs = hyp_lod.size() - 1;

    out_t->Resize({static_cast<int64_t>(num_strs), 1});
    // Allocate as T (not float) so the data<T>() call below stays consistent
    // with the kernel's template parameter.
    out_t->mutable_data<T>(ctx.GetPlace());
    auto out = out_t->data<T>();

    T distance = 0.0;
    for (size_t num = 0; num < num_strs; ++num) {
      auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
      auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);

      if (m == 0) {
        distance = n;
      } else if (n == 0) {
        distance = m;
      } else {
        framework::Tensor dist_t;
        dist_t.Resize({m + 1, n + 1});
        dist_t.mutable_data<T>(ctx.GetPlace());
        auto dist = dist_t.data<T>();
        auto x1 = x1_t->data<int>() + hyp_lod[num];
        auto x2 = x2_t->data<int>() + ref_lod[num];
        for (int64_t i = 0; i < m + 1; ++i) {
          dist[i * (n + 1)] = i;
        }
        for (int64_t j = 0; j < n + 1; ++j) {
          dist[j] = j;
        }
        for (int64_t i = 1; i < m + 1; ++i) {
          for (int64_t j = 1; j < n + 1; ++j) {
            int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
            int dels = dist[(i - 1) * (n + 1) + j] + 1;
            int ins = dist[i * (n + 1) + (j - 1)] + 1;
            int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
            dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
          }
        }
        distance = dist[m * (n + 1) + n];
      }

      if (normalized) {
        PADDLE_ENFORCE(n > 0,
                       "The reference string (#%d) cannot be empty "
                       "when Attr(normalized) is enabled.",
                       n);
        distance = distance / n;
      }
      out[num] = distance;
    }
  }
};

}  // namespace operators
}  // namespace paddle
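A design note on the CPU kernel above: it materializes the full (m + 1) x (n + 1) table, which is the simplest correct form. Since row i of the recurrence depends only on row i - 1, a two-row rolling buffer would suffice when only the final distance is needed; a hypothetical sketch of that variant (not part of this PR), shown in Python for brevity:

import numpy as np

def edit_distance_two_rows(hyp, ref):
    m, n = len(hyp), len(ref)
    prev = np.arange(n + 1, dtype="float32")    # dist[i - 1][:]
    for i in range(1, m + 1):
        cur = np.empty(n + 1, dtype="float32")  # dist[i][:]
        cur[0] = i
        for j in range(1, n + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + cost)
        prev = cur
    return prev[n]

assert edit_distance_two_rows("kitten", "sitting") == 3.0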
@@ -0,0 +1,94 @@
import unittest
import numpy as np
from op_test import OpTest


def Levenshtein(hyp, ref):
    """ Compute the Levenshtein distance between two strings.

    :param hyp: hypothesis string in index
    :type hyp: list
    :param ref: reference string in index
    :type ref: list
    """
    m = len(hyp)
    n = len(ref)
    if m == 0:
        return n
    if n == 0:
        return m

    dist = np.zeros((m + 1, n + 1)).astype("float32")
    for i in range(0, m + 1):
        dist[i][0] = i
    for j in range(0, n + 1):
        dist[0][j] = j

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            deletion = dist[i - 1][j] + 1
            insertion = dist[i][j - 1] + 1
            substitution = dist[i - 1][j - 1] + cost
            dist[i][j] = min(deletion, insertion, substitution)
    return dist[m][n]


class TestEditDistanceOp(OpTest):
    def setUp(self):
        self.op_type = "edit_distance"
        normalized = False
        x1 = np.array([[0, 12, 3, 5, 8, 2]]).astype("int32")
        x2 = np.array([[0, 12, 4, 7, 8]]).astype("int32")
        x1 = np.transpose(x1)
        x2 = np.transpose(x2)
        x1_lod = [0, 1, 5]
        x2_lod = [0, 3, 4]

        num_strs = len(x1_lod) - 1
        distance = np.zeros((num_strs, 1)).astype("float32")
        for i in range(0, num_strs):
            distance[i] = Levenshtein(
                hyp=x1[x1_lod[i]:x1_lod[i + 1]],
                ref=x2[x2_lod[i]:x2_lod[i + 1]])
            if normalized is True:
                len_ref = x2_lod[i + 1] - x2_lod[i]
                distance[i] = distance[i] / len_ref
        self.attrs = {'normalized': normalized}
        self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
        self.outputs = {'Out': distance}

    def test_check_output(self):
        self.check_output()


class TestEditDistanceOpNormalized(OpTest):
    def setUp(self):
        self.op_type = "edit_distance"
        normalized = True
        x1 = np.array([[0, 10, 3, 6, 5, 8, 2]]).astype("int32")
        x2 = np.array([[0, 10, 4, 6, 7, 8]]).astype("int32")
        x1 = np.transpose(x1)
        x2 = np.transpose(x2)
        x1_lod = [0, 1, 3, 6]
        x2_lod = [0, 2, 3, 5]

        num_strs = len(x1_lod) - 1
        distance = np.zeros((num_strs, 1)).astype("float32")
        for i in range(0, num_strs):
            distance[i] = Levenshtein(
                hyp=x1[x1_lod[i]:x1_lod[i + 1]],
                ref=x2[x2_lod[i]:x2_lod[i + 1]])
            if normalized is True:
                len_ref = x2_lod[i + 1] - x2_lod[i]
                distance[i] = distance[i] / len_ref
        self.attrs = {'normalized': normalized}
        self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
        self.outputs = {'Out': distance}

    def test_check_output(self):
        self.check_output()


if __name__ == '__main__':
    unittest.main()
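As a quick sanity check of the first test case: the LoD offsets split x1 into the hypothesis strings [0] and [12, 3, 5, 8] and x2 into the references [0, 12, 4] and [7], so the expected distances are 2 (two insertions) and 4 (three deletions plus one substitution). Run in the same module as the Levenshtein helper above:

print(Levenshtein(hyp=[0], ref=[0, 12, 4]))     # -> 2.0
print(Levenshtein(hyp=[12, 3, 5, 8], ref=[7]))  # -> 4.0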