|
|
|
@ -46,7 +46,11 @@ class LMDBReader(object):
|
|
|
|
|
if "tps" in params:
|
|
|
|
|
self.ues_tps = True
|
|
|
|
|
if "distort" in params:
|
|
|
|
|
self.use_distort = params['distort']
|
|
|
|
|
self.use_distort = params['distort'] and params['use_gpu']
|
|
|
|
|
if not params['use_gpu']:
|
|
|
|
|
logger.info(
|
|
|
|
|
"Distort operation can only support in GPU. Distort will be set to False."
|
|
|
|
|
)
|
|
|
|
|
if params['mode'] == 'train':
|
|
|
|
|
self.batch_size = params['train_batch_size_per_card']
|
|
|
|
|
self.drop_last = True
|
|
|
|
@ -189,7 +193,11 @@ class SimpleReader(object):
|
|
|
|
|
if "tps" in params:
|
|
|
|
|
self.use_tps = True
|
|
|
|
|
if "distort" in params:
|
|
|
|
|
self.use_distort = params['distort']
|
|
|
|
|
self.use_distort = params['distort'] and params['use_gpu']
|
|
|
|
|
if not params['use_gpu']:
|
|
|
|
|
logger.info(
|
|
|
|
|
"Distort operation can only support in GPU.Distort will be set to False."
|
|
|
|
|
)
|
|
|
|
|
if params['mode'] == 'train':
|
|
|
|
|
self.batch_size = params['train_batch_size_per_card']
|
|
|
|
|
self.drop_last = True
|
|
|
|
|