|
|
|
@ -14,6 +14,9 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
@ -36,11 +39,11 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void GetSegments(const int64_t* label, int length,
|
|
|
|
|
std::vector<Segment>& segments, int num_chunk_types,
|
|
|
|
|
std::vector<Segment>* segments, int num_chunk_types,
|
|
|
|
|
int num_tag_types, int other_chunk_type, int tag_begin,
|
|
|
|
|
int tag_inside, int tag_end, int tag_single) const {
|
|
|
|
|
segments.clear();
|
|
|
|
|
segments.reserve(length);
|
|
|
|
|
segments->clear();
|
|
|
|
|
segments->reserve(length);
|
|
|
|
|
int chunk_start = 0;
|
|
|
|
|
bool in_chunk = false;
|
|
|
|
|
int tag = -1;
|
|
|
|
@ -58,7 +61,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
|
|
|
|
|
i - 1, // end
|
|
|
|
|
prev_type,
|
|
|
|
|
};
|
|
|
|
|
segments.push_back(segment);
|
|
|
|
|
segments->push_back(segment);
|
|
|
|
|
in_chunk = false;
|
|
|
|
|
}
|
|
|
|
|
if (ChunkBegin(prev_tag, prev_type, tag, type, other_chunk_type,
|
|
|
|
@ -73,7 +76,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
|
|
|
|
|
length - 1, // end
|
|
|
|
|
type,
|
|
|
|
|
};
|
|
|
|
|
segments.push_back(segment);
|
|
|
|
|
segments->push_back(segment);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -177,8 +180,8 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (int i = 0; i < num_sequences; ++i) {
|
|
|
|
|
int seq_length = lod[0][i + 1] - lod[0][i];
|
|
|
|
|
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length,
|
|
|
|
|
output_segments, label_segments, *num_infer_chunks_data,
|
|
|
|
|
*num_label_chunks_data, *num_correct_chunks_data,
|
|
|
|
|
&output_segments, &label_segments, num_infer_chunks_data,
|
|
|
|
|
num_label_chunks_data, num_correct_chunks_data,
|
|
|
|
|
num_chunk_types, num_tag_types, other_chunk_type, tag_begin,
|
|
|
|
|
tag_inside, tag_end, tag_single, excluded_chunk_types);
|
|
|
|
|
}
|
|
|
|
@ -197,10 +200,10 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void EvalOneSeq(const int64_t* output, const int64_t* label, int length,
|
|
|
|
|
std::vector<Segment>& output_segments,
|
|
|
|
|
std::vector<Segment>& label_segments,
|
|
|
|
|
int64_t& num_output_segments, int64_t& num_label_segments,
|
|
|
|
|
int64_t& num_correct, int num_chunk_types, int num_tag_types,
|
|
|
|
|
std::vector<Segment>* output_segments,
|
|
|
|
|
std::vector<Segment>* label_segments,
|
|
|
|
|
int64_t* num_output_segments, int64_t* num_label_segments,
|
|
|
|
|
int64_t* num_correct, int num_chunk_types, int num_tag_types,
|
|
|
|
|
int other_chunk_type, int tag_begin, int tag_inside,
|
|
|
|
|
int tag_end, int tag_single,
|
|
|
|
|
const std::set<int>& excluded_chunk_types) const {
|
|
|
|
@ -209,25 +212,29 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
|
|
|
|
|
GetSegments(label, length, label_segments, num_chunk_types, num_tag_types,
|
|
|
|
|
other_chunk_type, tag_begin, tag_inside, tag_end, tag_single);
|
|
|
|
|
size_t i = 0, j = 0;
|
|
|
|
|
while (i < output_segments.size() && j < label_segments.size()) {
|
|
|
|
|
if (output_segments[i] == label_segments[j] &&
|
|
|
|
|
excluded_chunk_types.count(output_segments[i].type) != 1) {
|
|
|
|
|
++num_correct;
|
|
|
|
|
while (i < output_segments->size() && j < label_segments->size()) {
|
|
|
|
|
if (output_segments->at(i) == label_segments->at(j) &&
|
|
|
|
|
excluded_chunk_types.count(output_segments->at(i).type) != 1) {
|
|
|
|
|
++(*num_correct);
|
|
|
|
|
}
|
|
|
|
|
if (output_segments[i].end < label_segments[j].end) {
|
|
|
|
|
if (output_segments->at(i).end < label_segments->at(j).end) {
|
|
|
|
|
++i;
|
|
|
|
|
} else if (output_segments[i].end > label_segments[j].end) {
|
|
|
|
|
} else if (output_segments->at(i).end > label_segments->at(j).end) {
|
|
|
|
|
++j;
|
|
|
|
|
} else {
|
|
|
|
|
++i;
|
|
|
|
|
++j;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto& segment : label_segments) {
|
|
|
|
|
if (excluded_chunk_types.count(segment.type) != 1) ++num_label_segments;
|
|
|
|
|
for (auto& segment : (*label_segments)) {
|
|
|
|
|
if (excluded_chunk_types.count(segment.type) != 1) {
|
|
|
|
|
++(*num_label_segments);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto& segment : output_segments) {
|
|
|
|
|
if (excluded_chunk_types.count(segment.type) != 1) ++num_output_segments;
|
|
|
|
|
for (auto& segment : (*output_segments)) {
|
|
|
|
|
if (excluded_chunk_types.count(segment.type) != 1) {
|
|
|
|
|
++(*num_output_segments);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|