|
|
|
@ -93,21 +93,21 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
out_t->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto out = out_t->data<T>();
|
|
|
|
|
|
|
|
|
|
std::vector<T> distance(num_strs, 0.0);
|
|
|
|
|
T distance = 0.0;
|
|
|
|
|
for (size_t num = 0; num < num_strs; num++) {
|
|
|
|
|
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
|
|
|
|
|
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
|
|
|
|
|
if (m == 0 || n == 0) {
|
|
|
|
|
distance[num] = std::max(m, n);
|
|
|
|
|
distance = std::max(m, n);
|
|
|
|
|
if (normalized) {
|
|
|
|
|
PADDLE_ENFORCE(n > 0,
|
|
|
|
|
"The reference string (#%d) cannot be empty "
|
|
|
|
|
"when Attr(normalized) is enabled.",
|
|
|
|
|
n);
|
|
|
|
|
distance[num] = distance[num] / n;
|
|
|
|
|
distance = distance / n;
|
|
|
|
|
}
|
|
|
|
|
memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num,
|
|
|
|
|
platform::CPUPlace(), &distance[num], sizeof(T), stream);
|
|
|
|
|
platform::CPUPlace(), &distance, sizeof(T), stream);
|
|
|
|
|
} else {
|
|
|
|
|
framework::Tensor dist_t;
|
|
|
|
|
dist_t.Resize({m + 1, n + 1});
|
|
|
|
|