|
|
@ -67,6 +67,12 @@ def transpose_hwc2whc(image):
|
|
|
|
return image
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transpose_hwc2chw(image):
|
|
|
|
|
|
|
|
"""transpose image from HWC to CHW"""
|
|
|
|
|
|
|
|
image = np.transpose(image, (2, 0, 1))
|
|
|
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
|
|
|
|
def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
create train or evaluation dataset for warpctc
|
|
|
|
create train or evaluation dataset for warpctc
|
|
|
@ -91,7 +97,10 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_
|
|
|
|
c.TypeCast(mstype.int32)
|
|
|
|
c.TypeCast(mstype.int32)
|
|
|
|
]
|
|
|
|
]
|
|
|
|
ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
|
|
|
|
ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
|
|
|
|
ds = ds.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8)
|
|
|
|
if device_target == 'Ascend':
|
|
|
|
|
|
|
|
ds = ds.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
ds = ds.map(operations=transpose_hwc2chw, input_columns=["image"], num_parallel_workers=8)
|
|
|
|
ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)
|
|
|
|
ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)
|
|
|
|
|
|
|
|
|
|
|
|
ds = ds.batch(batch_size, drop_remainder=True)
|
|
|
|
ds = ds.batch(batch_size, drop_remainder=True)
|
|
|
|