@@ -100,9 +100,10 @@ class CTCLabelDecode(BaseRecLabelDecode):
                                              character_type, use_space_char)
 
     def __call__(self, preds, label=None, *args, **kwargs):
+        if isinstance(preds, paddle.Tensor):
+            preds = preds.numpy()
         # out = self.decode_preds(preds)
-        preds = F.softmax(preds, axis=2).numpy()
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
         text = self.decode(preds_idx, preds_prob)
@@ -116,19 +117,18 @@ class CTCLabelDecode(BaseRecLabelDecode):
         return dict_character
 
     def decode_preds(self, preds):
-        probs = F.softmax(preds, axis=2).numpy()
-        probs_ind = np.argmax(probs, axis=2)
+        probs_ind = np.argmax(preds, axis=2)
         B, N, _ = preds.shape
         l = np.ones(B).astype(np.int64) * N
-        length = paddle.to_variable(l)
+        length = paddle.to_tensor(l)
         out = paddle.fluid.layers.ctc_greedy_decoder(preds, 0, length)
         batch_res = [
             x[:idx[0]] for x, idx in zip(out[0].numpy(), out[1].numpy())
         ]
 
         result_list = []
-        for sample_idx, ind, prob in zip(batch_res, probs_ind, probs):
+        for sample_idx, ind, prob in zip(batch_res, probs_ind, preds):
             char_list = [self.character[idx] for idx in sample_idx]
             valid_ind = np.where(ind != 0)[0]
             if len(valid_ind) == 0:
@@ -172,4 +172,4 @@ class AttnLabelDecode(BaseRecLabelDecode):
         else:
             assert False, "unsupport type %s in get_beg_end_flag_idx" \
                 % beg_or_end
-        return idx
+        return idx