|
|
|
@ -52,9 +52,16 @@ def main():
|
|
|
|
|
|
|
|
|
|
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
|
|
|
|
|
|
|
|
|
infer_shape = [3, -1, -1]
|
|
|
|
|
infer_shape = [3, -1, -1]
|
|
|
|
|
if config['Architecture']['model_type'] == "rec":
|
|
|
|
|
infer_shape = [3, 32, -1]
|
|
|
|
|
infer_shape = [3, 32, -1] # for rec model, H must be 32
|
|
|
|
|
if 'Transform' in config['Architecture'] and config['Architecture'][
|
|
|
|
|
'Transform'] is not None and config['Architecture'][
|
|
|
|
|
'Transform']['name'] == 'TPS':
|
|
|
|
|
logger.info(
|
|
|
|
|
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
|
|
|
|
|
)
|
|
|
|
|
infer_shape[-1] = 100
|
|
|
|
|
|
|
|
|
|
model = to_static(
|
|
|
|
|
model,
|
|
|
|
|