|
|
|
@ -208,20 +208,46 @@ class ChunkEvaluator(Evaluator):
|
|
|
|
|
|
|
|
|
|
class EditDistance(Evaluator):
|
|
|
|
|
"""
|
|
|
|
|
Average edit distance error for multiple mini-batches.
|
|
|
|
|
Accumulate edit distance sum and sequence number from mini-batches and
|
|
|
|
|
compute the average edit_distance of all batches.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
input: the sequences predicted by network
|
|
|
|
|
label: the target sequences which must have the same sequence count
|
|
|
|
|
with input.
|
|
|
|
|
ignored_tokens(list of int): Tokens that should be removed before
|
|
|
|
|
calculating edit distance.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
|
|
exe = fluid.executor(place)
|
|
|
|
|
distance_evaluator = fluid.Evaluator.EditDistance(input, label)
|
|
|
|
|
for epoch in PASS_NUM:
|
|
|
|
|
distance_evaluator.reset(exe)
|
|
|
|
|
for data in batches:
|
|
|
|
|
loss, sum_distance = exe.run(fetch_list=[cost] + distance_evaluator.metrics)
|
|
|
|
|
avg_distance = distance_evaluator.eval(exe)
|
|
|
|
|
pass_distance = distance_evaluator.eval(exe)
|
|
|
|
|
|
|
|
|
|
In the above example:
|
|
|
|
|
'sum_distance' is the sum of the batch's edit distance.
|
|
|
|
|
'avg_distance' is the average edit distance from the first batch to the current batch.
|
|
|
|
|
'pass_distance' is the average edit distance over the whole pass.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, input, label, k=1, **kwargs):
|
|
|
|
|
def __init__(self, input, label, ignored_tokens=None, **kwargs):
|
|
|
|
|
super(EditDistance, self).__init__("edit_distance", **kwargs)
|
|
|
|
|
main_program = self.helper.main_program
|
|
|
|
|
if main_program.current_block().idx != 0:
|
|
|
|
|
raise ValueError("You can only invoke Evaluator in root block")
|
|
|
|
|
|
|
|
|
|
self.total_error = self.create_state(
|
|
|
|
|
dtype='float32', shape=[1], suffix='total')
|
|
|
|
|
dtype='float32', shape=[1], suffix='total_error')
|
|
|
|
|
self.seq_num = self.create_state(
|
|
|
|
|
dtype='int64', shape=[1], suffix='total')
|
|
|
|
|
error, seq_num = layers.edit_distance(input=input, label=label)
|
|
|
|
|
dtype='int64', shape=[1], suffix='seq_num')
|
|
|
|
|
error, seq_num = layers.edit_distance(
|
|
|
|
|
input=input, label=label, ignored_tokens=ignored_tokens)
|
|
|
|
|
#error = layers.cast(x=error, dtype='float32')
|
|
|
|
|
sum_error = layers.reduce_sum(error)
|
|
|
|
|
layers.sums(input=[self.total_error, sum_error], out=self.total_error)
|
|
|
|
|