diff --git a/model_zoo/official/cv/warpctc/src/dataset.py b/model_zoo/official/cv/warpctc/src/dataset.py index 11c3322f1e..98a91559e1 100755 --- a/model_zoo/official/cv/warpctc/src/dataset.py +++ b/model_zoo/official/cv/warpctc/src/dataset.py @@ -67,12 +67,6 @@ def transpose_hwc2whc(image): return image -def transpose_hwc2chw(image): - """transpose image from HWC to CHW""" - image = np.transpose(image, (2, 0, 1)) - return image - - def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'): """ create train or evaluation dataset for warpctc @@ -93,14 +87,20 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_ vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)), c.TypeCast(mstype.float16) ] + image_trans_gpu = [ + vc.Rescale(1.0 / 255.0, 0.0), + vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), + vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)), + vc.HWC2CHW() + ] label_trans = [ c.TypeCast(mstype.int32) ] - data_set = data_set.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8) if device_target == 'Ascend': + data_set = data_set.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8) data_set = data_set.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8) else: - data_set = data_set.map(operations=transpose_hwc2chw, input_columns=["image"], num_parallel_workers=8) + data_set = data_set.map(operations=image_trans_gpu, input_columns=["image"], num_parallel_workers=8) data_set = data_set.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8) data_set = data_set.batch(batch_size, drop_remainder=True) diff --git a/model_zoo/official/cv/warpctc/src/warpctc.py b/model_zoo/official/cv/warpctc/src/warpctc.py index 98c70a450b..e80bef8365 100755 --- a/model_zoo/official/cv/warpctc/src/warpctc.py +++ b/model_zoo/official/cv/warpctc/src/warpctc.py @@ -123,7 +123,6 @@ class StackedRNNForGPU(nn.Cell): self.transpose = P.Transpose() def construct(self, x): - x = self.cast(x, mstype.float32) x = self.transpose(x, (3, 0, 2, 1)) x = self.reshape(x, (-1, self.batch_size, self.input_size)) output, _ = self.lstm(x, (self.h, self.c))