|
|
@ -111,9 +111,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
|
|
|
|
std::vector<Segment> label_segments;
|
|
|
|
std::vector<Segment> label_segments;
|
|
|
|
std::vector<Segment> output_segments;
|
|
|
|
std::vector<Segment> output_segments;
|
|
|
|
std::set<int> excluded_chunk_types;
|
|
|
|
std::set<int> excluded_chunk_types;
|
|
|
|
int64_t num_output_segments = 0;
|
|
|
|
|
|
|
|
int64_t num_label_segments = 0;
|
|
|
|
|
|
|
|
int64_t num_correct = 0;
|
|
|
|
|
|
|
|
if (context.Attr<std::string>("chunk_scheme") == "IOB") {
|
|
|
|
if (context.Attr<std::string>("chunk_scheme") == "IOB") {
|
|
|
|
num_tag_types = 2;
|
|
|
|
num_tag_types = 2;
|
|
|
|
tag_begin = 0;
|
|
|
|
tag_begin = 0;
|
|
|
@ -151,12 +149,24 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
|
|
|
|
auto* precision = context.Output<Tensor>("Precision");
|
|
|
|
auto* precision = context.Output<Tensor>("Precision");
|
|
|
|
auto* recall = context.Output<Tensor>("Recall");
|
|
|
|
auto* recall = context.Output<Tensor>("Recall");
|
|
|
|
auto* f1 = context.Output<Tensor>("F1-Score");
|
|
|
|
auto* f1 = context.Output<Tensor>("F1-Score");
|
|
|
|
|
|
|
|
auto* num_infer_chunks = context.Output<Tensor>("NumInferChunks");
|
|
|
|
|
|
|
|
auto* num_label_chunks = context.Output<Tensor>("NumLabelChunks");
|
|
|
|
|
|
|
|
auto* num_correct_chunks = context.Output<Tensor>("NumCorrectChunks");
|
|
|
|
|
|
|
|
|
|
|
|
const int64_t* inference_data = inference->data<int64_t>();
|
|
|
|
const int64_t* inference_data = inference->data<int64_t>();
|
|
|
|
const int64_t* label_data = label->data<int64_t>();
|
|
|
|
const int64_t* label_data = label->data<int64_t>();
|
|
|
|
T* precision_data = precision->mutable_data<T>(context.GetPlace());
|
|
|
|
T* precision_data = precision->mutable_data<T>(context.GetPlace());
|
|
|
|
T* racall_data = recall->mutable_data<T>(context.GetPlace());
|
|
|
|
T* racall_data = recall->mutable_data<T>(context.GetPlace());
|
|
|
|
T* f1_data = f1->mutable_data<T>(context.GetPlace());
|
|
|
|
T* f1_data = f1->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
int64_t* num_infer_chunks_data =
|
|
|
|
|
|
|
|
num_infer_chunks->mutable_data<int64_t>(context.GetPlace());
|
|
|
|
|
|
|
|
int64_t* num_label_chunks_data =
|
|
|
|
|
|
|
|
num_label_chunks->mutable_data<int64_t>(context.GetPlace());
|
|
|
|
|
|
|
|
int64_t* num_correct_chunks_data =
|
|
|
|
|
|
|
|
num_correct_chunks->mutable_data<int64_t>(context.GetPlace());
|
|
|
|
|
|
|
|
*num_infer_chunks_data = 0;
|
|
|
|
|
|
|
|
*num_label_chunks_data = 0;
|
|
|
|
|
|
|
|
*num_correct_chunks_data = 0;
|
|
|
|
|
|
|
|
|
|
|
|
auto lod = label->lod();
|
|
|
|
auto lod = label->lod();
|
|
|
|
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
|
|
|
|
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
|
|
|
@ -166,17 +176,23 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
|
|
|
|
for (int i = 0; i < num_sequences; ++i) {
|
|
|
|
for (int i = 0; i < num_sequences; ++i) {
|
|
|
|
int seq_length = lod[0][i + 1] - lod[0][i];
|
|
|
|
int seq_length = lod[0][i + 1] - lod[0][i];
|
|
|
|
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length,
|
|
|
|
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length,
|
|
|
|
output_segments, label_segments, num_output_segments,
|
|
|
|
output_segments, label_segments, *num_infer_chunks_data,
|
|
|
|
num_label_segments, num_correct, num_chunk_types,
|
|
|
|
*num_label_chunks_data, *num_correct_chunks_data,
|
|
|
|
num_tag_types, other_chunk_type, tag_begin, tag_inside,
|
|
|
|
num_chunk_types, num_tag_types, other_chunk_type, tag_begin,
|
|
|
|
tag_end, tag_single, excluded_chunk_types);
|
|
|
|
tag_inside, tag_end, tag_single, excluded_chunk_types);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
*precision_data = !num_output_segments ? 0 : static_cast<T>(num_correct) /
|
|
|
|
*precision_data = !(*num_infer_chunks_data)
|
|
|
|
num_output_segments;
|
|
|
|
? 0
|
|
|
|
*racall_data = !num_label_segments ? 0 : static_cast<T>(num_correct) /
|
|
|
|
: static_cast<T>(*num_correct_chunks_data) /
|
|
|
|
num_label_segments;
|
|
|
|
(*num_infer_chunks_data);
|
|
|
|
*f1_data = !num_correct ? 0 : 2 * (*precision_data) * (*racall_data) /
|
|
|
|
*racall_data = !(*num_label_chunks_data)
|
|
|
|
((*precision_data) + (*racall_data));
|
|
|
|
? 0
|
|
|
|
|
|
|
|
: static_cast<T>(*num_correct_chunks_data) /
|
|
|
|
|
|
|
|
(*num_label_chunks_data);
|
|
|
|
|
|
|
|
*f1_data = !(*num_correct_chunks_data)
|
|
|
|
|
|
|
|
? 0
|
|
|
|
|
|
|
|
: 2 * (*precision_data) * (*racall_data) /
|
|
|
|
|
|
|
|
((*precision_data) + (*racall_data));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void EvalOneSeq(const int64_t* output, const int64_t* label, int length,
|
|
|
|
void EvalOneSeq(const int64_t* output, const int64_t* label, int length,
|
|
|
|