parent
6604d7cda2
commit
bb9d68dcb3
@ -0,0 +1,140 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/chunk_eval_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class ChunkEvalOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Inference"),
|
||||
"Input(Inference) of ChunkEvalOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Label"),
|
||||
"Input(Label) of ChunkEvalOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Precision"),
|
||||
"Output(Precision) of ChunkEvalOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Recall"),
|
||||
"Output(Recall) of ChunkEvalOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("F1-Score"),
|
||||
"Output(F1-Score) of ChunkEvalOp should not be null.");
|
||||
|
||||
auto inference_dim = ctx->GetInputDim("Inference");
|
||||
auto label_dim = ctx->GetInputDim("Label");
|
||||
|
||||
PADDLE_ENFORCE(inference_dim == label_dim,
|
||||
"Inference's shape must be the same as Label's shape.");
|
||||
|
||||
ctx->SetOutputDim("Precision", {1});
|
||||
ctx->SetOutputDim("Recall", {1});
|
||||
ctx->SetOutputDim("F1-Score", {1});
|
||||
}
|
||||
|
||||
framework::DataType IndicateDataType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
return framework::DataType::FP32;
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the proto of the chunk_eval operator: two inputs, three metric
// outputs, three attributes and the user-facing documentation string.
class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ChunkEvalOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Inputs.
    AddInput("Inference",
             "(Tensor, default: Tensor<int>) Predictions from the network.");
    AddInput("Label", "(Tensor, default: Tensor<int>) Labels of the data.");

    // Outputs.
    AddOutput(
        "Precision",
        "(float) The precision ratio of the predictions on current data.");
    AddOutput("Recall",
              "(float) The recall ratio of the predictions on current data.");
    AddOutput("F1-Score",
              "(float) The F1-Score of the predictions on current data.");

    // Attributes; chunk_scheme and excluded_chunk_types carry defaults.
    AddAttr<int>("num_chunk_types", "(int) The number of chunk type.");
    AddAttr<std::string>("chunk_scheme",
                         "(string, default IOB) The label scheme.")
        .SetDefault("IOB");
    AddAttr<std::vector<int>>(
        "excluded_chunk_types",
        "(list<int>) A list<int> indicating chunk types not to be counted.")
        .SetDefault(std::vector<int>{});

    AddComment(R"DOC(
Chunk evaluator is used to evaluate segment labelling accuracy for a
sequence. It calculates precision, recall and F1 scores for the chunk detection.
To use chunk evaluator, several concepts need to be clarified firstly.
[Chunk type] is the type of the whole chunk and a chunk consists of one or several words. (For example in NER, ORG for organization name, PER for person name etc.)
[Tag type] indicates the position of a word in a chunk. (B for begin, I for inside, E for end, S for single)
We can name a label by combining tag type and chunk type. (ie. B-ORG for begining of an organization name)
The construction of label dictionary should obey the following rules:
- Use one of the listed labelling schemes. These schemes differ in ways indicating chunk boundry.

Scheme Description
plain Use the same label for the whole chunk.
IOB Two labels for chunk type X, B-X for chunk begining and I-X for chunk inside.
IOE Two labels for chunk type X, E-X for chunk ending and I-X for chunk inside.
IOBES Four labels for chunk type X, B-X for chunk begining, I-X for chunk inside, E-X for chunk end and S-X for single word chunk.

To make it clear, let's illustrate by an NER example.
Assuming that there are three named entity types including ORG, PER and LOC which are called 'chunk type' here,
if 'IOB' scheme were used, the label set will be extended to a set including B-ORG, I-ORG, B-PER, I-PER, B-LOC, I-LOC and O,
in which B-ORG for begining of ORG and I-ORG for inside of ORG.
Prefixes which are called 'tag type' here are added to chunk types and there are two tag types including B and I.
Of course, the training data should be labeled accordingly.
- Mapping is done correctly by the listed equations and assigning protocol.
The following table are equations to extract tag type and chunk type from a label.

tagType = label % numTagType
chunkType = label / numTagType
otherChunkType = numChunkTypes

The following table shows the mapping rule between tagType and tag type in each scheme.

Scheme Begin Inside End Single
plain 0 - - -
IOB 0 1 - -
IOE - 0 1 -
IOBES 0 1 2 3

Continue the NER example, and the label dict should look like this to satify above equations:

B-ORG 0
I-ORG 1
B-PER 2
I-PER 3
B-LOC 4
I-LOC 5
O 6

In this example, chunkType has three values: 0 for ORG, 1 for PER, 2 for LOC, because the scheme is
"IOB" so tagType has two values: 0 for B and 1 for I.
Here we will use I-LOC to explain the above mapping rules in detail.
For I-LOC, the label id is 5, so we can get tagType=1 and chunkType=2, which means I-LOC is a part of NER chunk LOC
and the tag is I.
)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(chunk_eval, ops::ChunkEvalOp,
|
||||
ops::ChunkEvalOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(chunk_eval,
|
||||
ops::ChunkEvalKernel<paddle::platform::CPUPlace, float>);
|
@ -0,0 +1,219 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <set>
|
||||
#include "paddle/framework/eigen.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
template <typename Place, typename T>
|
||||
class ChunkEvalKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
struct Segment {
|
||||
int begin;
|
||||
int end;
|
||||
int type;
|
||||
bool operator==(const Segment& y) const {
|
||||
return begin == y.begin && end == y.end && type == y.type;
|
||||
}
|
||||
};
|
||||
|
||||
void GetSegments(const int* label, int length, std::vector<Segment>& segments,
|
||||
int num_chunk_types, int num_tag_types, int other_chunk_type,
|
||||
int tag_begin, int tag_inside, int tag_end,
|
||||
int tag_single) const {
|
||||
segments.clear();
|
||||
segments.reserve(length);
|
||||
int chunk_start = 0;
|
||||
bool in_chunk = false;
|
||||
int tag = -1;
|
||||
int type = other_chunk_type;
|
||||
for (int i = 0; i < length; ++i) {
|
||||
int prev_tag = tag;
|
||||
int prev_type = type;
|
||||
PADDLE_ENFORCE_LE(label[i], num_chunk_types * num_tag_types);
|
||||
tag = label[i] % num_tag_types;
|
||||
type = label[i] / num_tag_types;
|
||||
if (in_chunk && ChunkEnd(prev_tag, prev_type, tag, type, other_chunk_type,
|
||||
tag_begin, tag_inside, tag_end, tag_single)) {
|
||||
Segment segment{
|
||||
chunk_start, // begin
|
||||
i - 1, // end
|
||||
prev_type,
|
||||
};
|
||||
segments.push_back(segment);
|
||||
in_chunk = false;
|
||||
}
|
||||
if (ChunkBegin(prev_tag, prev_type, tag, type, other_chunk_type,
|
||||
tag_begin, tag_inside, tag_end, tag_single)) {
|
||||
chunk_start = i;
|
||||
in_chunk = true;
|
||||
}
|
||||
}
|
||||
if (in_chunk) {
|
||||
Segment segment{
|
||||
chunk_start, // begin
|
||||
length - 1, // end
|
||||
type,
|
||||
};
|
||||
segments.push_back(segment);
|
||||
}
|
||||
}
|
||||
|
||||
bool ChunkEnd(int prev_tag, int prev_type, int tag, int type,
|
||||
int other_chunk_type, int tag_begin, int tag_inside,
|
||||
int tag_end, int tag_single) const {
|
||||
if (prev_type == other_chunk_type) return false;
|
||||
if (type == other_chunk_type) return true;
|
||||
if (type != prev_type) return true;
|
||||
if (prev_tag == tag_begin) return tag == tag_begin || tag == tag_single;
|
||||
if (prev_tag == tag_inside) return tag == tag_begin || tag == tag_single;
|
||||
if (prev_tag == tag_end) return true;
|
||||
if (prev_tag == tag_single) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ChunkBegin(int prev_tag, int prev_type, int tag, int type,
|
||||
int other_chunk_type, int tag_begin, int tag_inside,
|
||||
int tag_end, int tag_single) const {
|
||||
if (prev_type == other_chunk_type) return type != other_chunk_type;
|
||||
if (type == other_chunk_type) return false;
|
||||
if (type != prev_type) return true;
|
||||
if (tag == tag_begin) return true;
|
||||
if (tag == tag_inside) return prev_tag == tag_end || prev_tag == tag_single;
|
||||
if (tag == tag_end) return prev_tag == tag_end || prev_tag == tag_single;
|
||||
if (tag == tag_single) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
// initialize to parse configurations
|
||||
int num_chunk_types, num_tag_types;
|
||||
int other_chunk_type;
|
||||
int tag_begin, tag_inside, tag_end, tag_single;
|
||||
std::vector<Segment> label_segments;
|
||||
std::vector<Segment> output_segments;
|
||||
std::set<int> excluded_chunk_types;
|
||||
int64_t num_output_segments = 0;
|
||||
int64_t num_label_segments = 0;
|
||||
int64_t num_correct = 0;
|
||||
if (context.Attr<std::string>("chunk_scheme") == "IOB") {
|
||||
num_tag_types = 2;
|
||||
tag_begin = 0;
|
||||
tag_inside = 1;
|
||||
tag_end = -1;
|
||||
tag_single = -1;
|
||||
} else if (context.Attr<std::string>("chunk_scheme") == "IOE") {
|
||||
num_tag_types = 2;
|
||||
tag_begin = -1;
|
||||
tag_inside = 0;
|
||||
tag_end = 1;
|
||||
tag_single = -1;
|
||||
} else if (context.Attr<std::string>("chunk_scheme") == "IOBES") {
|
||||
num_tag_types = 4;
|
||||
tag_begin = 0;
|
||||
tag_inside = 1;
|
||||
tag_end = 2;
|
||||
tag_single = 3;
|
||||
} else if (context.Attr<std::string>("chunk_scheme") == "plain") {
|
||||
num_tag_types = 1;
|
||||
tag_begin = -1;
|
||||
tag_inside = -1;
|
||||
tag_end = -1;
|
||||
tag_single = -1;
|
||||
} else {
|
||||
PADDLE_THROW("Unknown chunk scheme.");
|
||||
}
|
||||
other_chunk_type = num_chunk_types = context.Attr<int>("num_chunk_types");
|
||||
excluded_chunk_types.insert(
|
||||
context.Attr<std::vector<int>>("excluded_chunk_types").begin(),
|
||||
context.Attr<std::vector<int>>("excluded_chunk_types").end());
|
||||
|
||||
auto* inference = context.Input<LoDTensor>("Inference");
|
||||
auto* label = context.Input<LoDTensor>("Label");
|
||||
auto* precision = context.Output<Tensor>("Precision");
|
||||
auto* recall = context.Output<Tensor>("Recall");
|
||||
auto* f1 = context.Output<Tensor>("F1-Score");
|
||||
|
||||
const int* inference_data = inference->data<int>();
|
||||
const int* label_data = label->data<int>();
|
||||
T* precision_data = precision->mutable_data<T>(context.GetPlace());
|
||||
T* racall_data = recall->mutable_data<T>(context.GetPlace());
|
||||
T* f1_data = f1->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto lod = label->lod();
|
||||
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
|
||||
PADDLE_ENFORCE(lod == inference->lod(),
|
||||
"LoD must be same between Inference and Label.");
|
||||
int num_sequences = lod[0].size() - 1;
|
||||
for (int i = 0; i < num_sequences; ++i) {
|
||||
int seq_length = lod[0][i + 1] - lod[0][i];
|
||||
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length,
|
||||
output_segments, label_segments, num_output_segments,
|
||||
num_label_segments, num_correct, num_chunk_types,
|
||||
num_tag_types, other_chunk_type, tag_begin, tag_inside,
|
||||
tag_end, tag_single, excluded_chunk_types);
|
||||
}
|
||||
*precision_data =
|
||||
!num_output_segments ? 0 : (T)num_correct / num_output_segments;
|
||||
*racall_data =
|
||||
!num_label_segments ? 0 : (T)num_correct / num_label_segments;
|
||||
*f1_data = !num_correct ? 0 : 2 * (*precision_data) * (*racall_data) /
|
||||
((*precision_data) + (*racall_data));
|
||||
}
|
||||
|
||||
void EvalOneSeq(const int* output, const int* label, int length,
|
||||
std::vector<Segment>& output_segments,
|
||||
std::vector<Segment>& label_segments,
|
||||
int64_t& num_output_segments, int64_t& num_label_segments,
|
||||
int64_t& num_correct, int num_chunk_types, int num_tag_types,
|
||||
int other_chunk_type, int tag_begin, int tag_inside,
|
||||
int tag_end, int tag_single,
|
||||
const std::set<int>& excluded_chunk_types) const {
|
||||
GetSegments(output, length, output_segments, num_chunk_types, num_tag_types,
|
||||
other_chunk_type, tag_begin, tag_inside, tag_end, tag_single);
|
||||
GetSegments(label, length, label_segments, num_chunk_types, num_tag_types,
|
||||
other_chunk_type, tag_begin, tag_inside, tag_end, tag_single);
|
||||
size_t i = 0, j = 0;
|
||||
while (i < output_segments.size() && j < label_segments.size()) {
|
||||
if (output_segments[i] == label_segments[j] &&
|
||||
excluded_chunk_types.count(output_segments[i].type) != 1) {
|
||||
++num_correct;
|
||||
}
|
||||
if (output_segments[i].end < label_segments[j].end) {
|
||||
++i;
|
||||
} else if (output_segments[i].end > label_segments[j].end) {
|
||||
++j;
|
||||
} else {
|
||||
++i;
|
||||
++j;
|
||||
}
|
||||
}
|
||||
for (auto& segment : label_segments) {
|
||||
if (excluded_chunk_types.count(segment.type) != 1) ++num_label_segments;
|
||||
}
|
||||
for (auto& segment : output_segments) {
|
||||
if (excluded_chunk_types.count(segment.type) != 1) ++num_output_segments;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,176 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class Segments(object):
    """A chunk described by its type id and its start and end indices."""

    def __init__(self, chunk_type, start_idx, end_idx):
        self.chunk_type, self.start_idx, self.end_idx = (chunk_type, start_idx,
                                                         end_idx)

    def __repr__(self):
        fields = (self.chunk_type, self.start_idx, self.end_idx)
        return '(Segments: %s, %s, %s)' % fields

    # str() and repr() render identically.
    __str__ = __repr__
|
||||
|
||||
|
||||
class TestChunkEvalOp(OpTest):
    """Checks chunk_eval on randomly generated label sequences.

    Random non-overlapping chunks are written into the Inference and Label
    arrays so that the number of correct, inferred and labelled chunks is
    known in advance; the expected metrics are derived from those counts.
    """

    num_sequences = 5
    batch_size = 50

    def parse_scheme(self):
        # Both schemes handled here use two tag types per chunk type.
        if self.scheme in ('IOB', 'IOE'):
            self.num_tag_types = 2

    def fill_with_chunks(self, data, chunks):
        """Write the label ids of every chunk in `chunks` into `data` in place."""
        for seg in chunks:
            first, last = seg.start_idx, seg.end_idx
            if self.scheme == 'IOB':
                begin_label = seg.chunk_type * self.num_tag_types
                inside_label = begin_label + (self.num_tag_types - 1)
                data[first] = begin_label
                data[first + 1:last] = inside_label
                # A one-element chunk keeps its begin label at the end index.
                data[last] = inside_label if first < last else data[first]
            elif self.scheme == 'IOE':
                inside_label = seg.chunk_type * self.num_tag_types
                data[first:last] = inside_label
                data[last] = inside_label + (self.num_tag_types - 1)

    def rand_chunks(self, starts, num_chunks):
        """Draw `num_chunks` random chunks, none crossing a sequence boundary.

        `starts` holds the sorted sequence offsets; a negative `num_chunks`
        asks for a random count.
        """
        total_len = starts[-1]
        if num_chunks < 0:
            num_chunks = np.random.randint(total_len)

        # Distinct begin positions, in ascending order.
        chunk_begins = sorted(
            np.random.choice(
                range(total_len), num_chunks, replace=False))

        # Group the begin positions by the sequence that contains them.
        per_seq_begins = []
        cursor = 0
        for seq_end in starts[1:]:
            group = []
            while cursor < len(chunk_begins) and chunk_begins[
                    cursor] < seq_end:
                group.append(chunk_begins[cursor])
                cursor += 1
            per_seq_begins.append(group)

        # Each chunk ends before the next chunk begins, or before the end of
        # its sequence for the last chunk of that sequence.
        chunk_ends = []
        for seq_end, group in zip(starts[1:], per_seq_begins):
            upper_bounds = group[1:] + [seq_end]
            for low, high in zip(group, upper_bounds):
                chunk_ends.append(np.random.randint(low, high))

        # Attach a random chunk type to every (begin, end) pair.
        return [
            Segments(np.random.randint(self.num_chunk_types), begin, end)
            for begin, end in zip(chunk_begins, chunk_ends)
        ]

    def gen_chunks(self, infer, label, starts):
        """Fill `infer` and `label` and return the expected chunk counts.

        Returns (num_correct, num_infer, num_label) after discounting chunks
        whose type is listed in `excluded_chunk_types`.
        """
        pool_size = (self.num_infer_chunks + self.num_label_chunks -
                     self.num_correct_chunks)
        chunks = self.rand_chunks(starts, pool_size)
        all_ids = range(len(chunks))

        # Chunks present in both arrays.
        correct_chunks = np.random.choice(
            all_ids, self.num_correct_chunks, replace=False)

        # Chunks present only in the inference array.
        infer_only = np.random.choice(
            [x for x in all_ids if x not in correct_chunks],
            self.num_infer_chunks - self.num_correct_chunks,
            replace=False)
        infer_chunks = sorted(correct_chunks.tolist() + infer_only.tolist())

        # Chunks present only in the label array.
        label_only = np.random.choice(
            [x for x in all_ids if x not in infer_chunks],
            self.num_label_chunks - self.num_correct_chunks,
            replace=False)
        label_chunks = sorted(correct_chunks.tolist() + label_only.tolist())

        self.fill_with_chunks(infer, [chunks[idx] for idx in infer_chunks])
        self.fill_with_chunks(label, [chunks[idx] for idx in label_chunks])

        # Chunks of an excluded type do not count towards any total.
        def count_excluded(ids):
            return sum(1 for idx in ids
                       if chunks[idx].chunk_type in self.excluded_chunk_types)

        self.num_correct_chunks -= count_excluded(correct_chunks)
        self.num_infer_chunks -= count_excluded(infer_chunks)
        self.num_label_chunks -= count_excluded(label_chunks)
        return (self.num_correct_chunks, self.num_infer_chunks,
                self.num_label_chunks)

    def set_confs(self):
        # Use the IOB scheme and labels with 2 chunk types
        self.scheme = 'IOB'
        self.num_chunk_types = 2
        self.excluded_chunk_types = []
        self.other_chunk_type = self.num_chunk_types
        self.attrs = {
            'num_chunk_types': self.num_chunk_types,
            'chunk_scheme': self.scheme,
            'excluded_chunk_types': self.excluded_chunk_types
        }
        self.parse_scheme()
        self.num_correct_chunks = 4
        self.num_infer_chunks = 5
        self.num_label_chunks = 9

    def set_data(self):
        # Start with every position labelled as "outside any chunk".
        other_label = self.num_chunk_types * self.num_tag_types
        infer = np.full((self.batch_size, ), other_label, dtype="int32")
        label = np.copy(infer)

        # Random interior boundaries plus both ends give the sequence offsets.
        boundaries = np.random.choice(
            range(1, self.batch_size), self.num_sequences - 1,
            replace=False).tolist()
        starts = sorted(boundaries + [0, self.batch_size])

        counts = self.gen_chunks(infer, label, starts)
        (self.num_correct_chunks, self.num_infer_chunks,
         self.num_label_chunks) = counts

        self.inputs = {
            'Inference': (infer, [starts]),
            'Label': (label, [starts])
        }

        correct = self.num_correct_chunks
        precision = (float(correct) / self.num_infer_chunks
                     if self.num_infer_chunks else 0)
        recall = (float(correct) / self.num_label_chunks
                  if self.num_label_chunks else 0)
        f1 = (float(2 * precision * recall) / (precision + recall)
              if correct else 0)
        self.outputs = {
            'Precision': [precision],
            'Recall': [recall],
            'F1-Score': [f1]
        }

    def setUp(self):
        self.op_type = 'chunk_eval'
        self.set_confs()
        self.set_data()

    def test_check_output(self):
        self.check_output()
|
||||
|
||||
|
||||
class TestChunkEvalOpWithExclude(TestChunkEvalOp):
    """Same check with the IOE scheme and chunk type 1 excluded."""

    def set_confs(self):
        # Use the IOE scheme and labels with 3 chunk types
        self.scheme = 'IOE'
        self.num_chunk_types = 3
        self.excluded_chunk_types = [1]
        self.other_chunk_type = self.num_chunk_types
        self.attrs = {
            'num_chunk_types': self.num_chunk_types,
            'chunk_scheme': self.scheme,
            'excluded_chunk_types': self.excluded_chunk_types
        }
        self.parse_scheme()
        self.num_correct_chunks = 15
        self.num_infer_chunks = 18
        self.num_label_chunks = 20
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue