|
|
@ -410,7 +410,8 @@ def resize_norm_img_srn(img, image_shape):
|
|
|
|
|
|
|
|
|
|
|
|
def srn_other_inputs(image_shape,
|
|
|
|
def srn_other_inputs(image_shape,
|
|
|
|
num_heads,
|
|
|
|
num_heads,
|
|
|
|
max_text_length):
|
|
|
|
max_text_length,
|
|
|
|
|
|
|
|
char_num):
|
|
|
|
|
|
|
|
|
|
|
|
imgC, imgH, imgW = image_shape
|
|
|
|
imgC, imgH, imgW = image_shape
|
|
|
|
feature_dim = int((imgH / 8) * (imgW / 8))
|
|
|
|
feature_dim = int((imgH / 8) * (imgW / 8))
|
|
|
@ -418,7 +419,7 @@ def srn_other_inputs(image_shape,
|
|
|
|
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
|
|
|
|
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
|
|
|
|
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
|
|
|
|
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
|
|
|
|
|
|
|
|
|
|
|
|
lbl_weight = np.array([37] * max_text_length).reshape((-1,1)).astype('int64')
|
|
|
|
lbl_weight = np.array([int(char_num-1)] * max_text_length).reshape((-1,1)).astype('int64')
|
|
|
|
|
|
|
|
|
|
|
|
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
|
|
|
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
|
|
|
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
|
|
|
|
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
|
|
|
@ -441,17 +442,18 @@ def process_image_srn(img,
|
|
|
|
loss_type=None):
|
|
|
|
loss_type=None):
|
|
|
|
norm_img = resize_norm_img_srn(img, image_shape)
|
|
|
|
norm_img = resize_norm_img_srn(img, image_shape)
|
|
|
|
norm_img = norm_img[np.newaxis, :]
|
|
|
|
norm_img = norm_img[np.newaxis, :]
|
|
|
|
|
|
|
|
char_num = char_ops.get_char_num()
|
|
|
|
|
|
|
|
|
|
|
|
[lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
|
|
|
[lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
|
|
|
srn_other_inputs(image_shape, num_heads, max_text_length)
|
|
|
|
srn_other_inputs(image_shape, num_heads, max_text_length,char_num)
|
|
|
|
|
|
|
|
|
|
|
|
if label is not None:
|
|
|
|
if label is not None:
|
|
|
|
char_num = char_ops.get_char_num()
|
|
|
|
|
|
|
|
text = char_ops.encode(label)
|
|
|
|
text = char_ops.encode(label)
|
|
|
|
if len(text) == 0 or len(text) > max_text_length:
|
|
|
|
if len(text) == 0 or len(text) > max_text_length:
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if loss_type == "srn":
|
|
|
|
if loss_type == "srn":
|
|
|
|
text_padded = [37] * max_text_length
|
|
|
|
text_padded = [int(char_num-1)] * max_text_length
|
|
|
|
for i in range(len(text)):
|
|
|
|
for i in range(len(text)):
|
|
|
|
text_padded[i] = text[i]
|
|
|
|
text_padded[i] = text[i]
|
|
|
|
lbl_weight[i] = [1.0]
|
|
|
|
lbl_weight[i] = [1.0]
|
|
|
|