diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py index 437dd09b0c..930bc37104 100644 --- a/mindspore/dataset/transforms/c_transforms.py +++ b/mindspore/dataset/transforms/c_transforms.py @@ -321,14 +321,16 @@ class Duplicate(cde.DuplicateOp): class Unique(cde.UniqueOp): """ - Return an output tensor containing all the unique elements of the input tensor in - the same order that they occur in the input tensor. + Perform the unique operation on the input tensor, only support transform one column each time. - Also return an index tensor that contains the index of each element of the - input tensor in the Unique output tensor. + Return 3 tensor: unique output tensor, index tensor, count tensor. - Finally, return a count tensor that contains the count of each element of - the output tensor in the input tensor. + Unique output tensor contains all the unique elements of the input tensor + in the same order that they occur in the input tensor. + + Index tensor that contains the index of each element of the input tensor in the unique output tensor. + + Count tensor that contains the count of each element of the output tensor in the input tensor. Note: Call batch op before calling this function. diff --git a/mindspore/mindrecord/tools/cifar10.py b/mindspore/mindrecord/tools/cifar10.py index 0ede2cc0ca..55c7ebf0fa 100644 --- a/mindspore/mindrecord/tools/cifar10.py +++ b/mindspore/mindrecord/tools/cifar10.py @@ -57,7 +57,13 @@ def restricted_loads(s): if isinstance(s, str): raise TypeError("can not load pickle from unicode string") f = io.BytesIO(s) - return RestrictedUnpickler(f, encoding='bytes').load() + try: + return RestrictedUnpickler(f, encoding='bytes').load() + except pickle.UnpicklingError: + raise RuntimeError("Not a valid Cifar10 Dataset.") + else: + raise RuntimeError("Unexpected error while Unpickling Cifar10 Dataset.") + class Cifar10: """ diff --git a/mindspore/mindrecord/tools/cifar100.py b/mindspore/mindrecord/tools/cifar100.py index 339025bb0a..f480dfa17d 100644 --- a/mindspore/mindrecord/tools/cifar100.py +++ b/mindspore/mindrecord/tools/cifar100.py @@ -56,7 +56,13 @@ def restricted_loads(s): if isinstance(s, str): raise TypeError("can not load pickle from unicode string") f = io.BytesIO(s) - return RestrictedUnpickler(f, encoding='bytes').load() + try: + return RestrictedUnpickler(f, encoding='bytes').load() + except pickle.UnpicklingError: + raise RuntimeError("Not a valid Cifar100 Dataset.") + else: + raise RuntimeError("Unexpected error while Unpickling Cifar100 Dataset.") + class Cifar100: """