Add python wrapper for ctc_evaluator

fix-profile-doc-typo
wanghaoshuang 8 years ago
parent 144854d2e8
commit 0dd3919a21

@ -50,6 +50,7 @@ __all__ = [
'sequence_last_step',
'dropout',
'split',
'greedy_ctc_evaluator',
]
@ -1597,3 +1598,39 @@ def split(input, num_or_sections, dim=-1):
'axis': dim
})
return outs
def greedy_ctc_evaluator(input, label, blank, normalized=False, name=None):
"""
"""
helper = LayerHelper("greedy_ctc_evalutor", **locals())
# top 1 op
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": 1})
# ctc align op
ctc_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="ctc_align",
inputs={"Input": [topk_indices]},
outputs={"Out": [ctc_out]},
attrs={"merge_repeated": True,
"blank": blank})
# edit distance op
edit_distance_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="edit_distance",
inputs={"Hyps": [ctc_out],
"Refs": [label]},
outputs={"Out": [edit_distance_out]},
attrs={"normalized": normalized})
return edit_distance_out

Loading…
Cancel
Save