|
|
@ -70,6 +70,7 @@ class BaseRecLabelDecode(object):
|
|
|
|
if text_index[batch_idx][idx] in ignored_tokens:
|
|
|
|
if text_index[batch_idx][idx] in ignored_tokens:
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
if is_remove_duplicate:
|
|
|
|
if is_remove_duplicate:
|
|
|
|
|
|
|
|
# only for predict
|
|
|
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
|
|
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
|
|
|
batch_idx][idx]:
|
|
|
|
batch_idx][idx]:
|
|
|
|
continue
|
|
|
|
continue
|
|
|
@ -107,7 +108,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|
|
|
text = self.decode(preds_idx, preds_prob)
|
|
|
|
text = self.decode(preds_idx, preds_prob)
|
|
|
|
if label is None:
|
|
|
|
if label is None:
|
|
|
|
return text
|
|
|
|
return text
|
|
|
|
label = self.decode(label)
|
|
|
|
label = self.decode(label, is_remove_duplicate=False)
|
|
|
|
return text, label
|
|
|
|
return text, label
|
|
|
|
|
|
|
|
|
|
|
|
def add_special_char(self, dict_character):
|
|
|
|
def add_special_char(self, dict_character):
|
|
|
|