diff --git a/model_zoo/official/cv/warpctc/src/dataset.py b/model_zoo/official/cv/warpctc/src/dataset.py index 12b75bd3bb..34c9cc8e84 100755 --- a/model_zoo/official/cv/warpctc/src/dataset.py +++ b/model_zoo/official/cv/warpctc/src/dataset.py @@ -67,6 +67,12 @@ def transpose_hwc2whc(image): return image +def transpose_hwc2chw(image): + """transpose image from HWC to CHW""" + image = np.transpose(image, (2, 0, 1)) + return image + + def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'): """ create train or evaluation dataset for warpctc @@ -91,7 +97,10 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_ c.TypeCast(mstype.int32) ] ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8) - ds = ds.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8) + if device_target == 'Ascend': + ds = ds.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8) + else: + ds = ds.map(operations=transpose_hwc2chw, input_columns=["image"], num_parallel_workers=8) ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8) ds = ds.batch(batch_size, drop_remainder=True)