|
|
|
@ -182,14 +182,15 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
|
|
|
|
|
|
|
|
|
preds_prob = np.reshape(preds_prob, [-1, 25])
|
|
|
|
|
|
|
|
|
|
text = self.decode(preds_idx, preds_prob)
|
|
|
|
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
|
|
|
|
|
|
|
|
|
if label is None:
|
|
|
|
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
|
|
|
|
return text
|
|
|
|
|
label = self.decode(label, is_remove_duplicate=False)
|
|
|
|
|
label = self.decode(label, is_remove_duplicate=True)
|
|
|
|
|
return text, label
|
|
|
|
|
|
|
|
|
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=True):
|
|
|
|
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
|
|
|
|
""" convert text-index into text-label. """
|
|
|
|
|
result_list = []
|
|
|
|
|
ignored_tokens = self.get_ignored_tokens()
|
|
|
|
|