diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index cc99dc9052..f1bde116e6 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -37,6 +37,9 @@
 from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp
 from mindspore._c_expression import typing
 from mindspore import log as logger
+
+import mindspore.dataset.transforms.py_transforms as py_transforms
+
 from . import samplers
 from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator
 from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
@@ -406,7 +409,7 @@ class Dataset:
         return dataset

     @check_map
-    def map(self, operations=None, input_columns=None, output_columns=None, column_order=None,
+    def map(self, operations, input_columns=None, output_columns=None, column_order=None,
             num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
         """
         Apply each operation in operations to this dataset.
@@ -427,7 +430,7 @@
         Args:
             operations (Union[list[TensorOp], list[functions]]): List of operations to be
                 applied on the dataset. Operations are applied in the order they appear in this list.
-            input_columns (list[str]): List of the names of the columns that will be passed to
+            input_columns (list[str], optional): List of the names of the columns that will be passed to
                 the first operation as input. The size of this list must match the number of
                 input columns expected by the first operator. (default=None, the first
                 operation will be passed however many columns that is required, starting from
@@ -2021,8 +2024,25 @@ class MapDataset(DatasetOp):
                  num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
         super().__init__(num_parallel_workers)
         self.children.append(input_dataset)
-        if operations is not None and not isinstance(operations, list):
-            operations = [operations]
+        if operations is not None:
+            if not isinstance(operations, list):
+                operations = [operations]
+            elif isinstance(operations, list) and len(operations) > 1:
+                # wrap adjacent Python operations in a Compose to allow mixing of Python and C++ operations
+                new_ops, start_ind, end_ind = [], 0, 0
+                for i, op in enumerate(operations):
+                    if not callable(op):
+                        # reset counts
+                        if start_ind != end_ind:
+                            new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
+                        new_ops.append(op)
+                        start_ind, end_ind = i + 1, i + 1
+                    else:
+                        end_ind += 1
+                # do an additional check in case the last operation is a Python operation
+                if start_ind != end_ind:
+                    new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
+                operations = new_ops
         self.operations = operations
         if input_columns is not None and not isinstance(input_columns, list):
             input_columns = [input_columns]
diff --git a/mindspore/dataset/transforms/py_transforms.py b/mindspore/dataset/transforms/py_transforms.py
index 0dc1445cdd..123e0f618a 100644
--- a/mindspore/dataset/transforms/py_transforms.py
+++ b/mindspore/dataset/transforms/py_transforms.py
@@ -87,6 +87,38 @@ class Compose:
         >>>                      py_vision.RandomErasing()])
         >>> # apply the transform to the dataset through dataset.map()
         >>> dataset = dataset.map(operations=transform, input_columns="image")
+        >>>
+        >>> # Compose can also be invoked implicitly, by just passing in a list of ops
+        >>> # the above example then becomes:
+        >>> transform_list = [py_vision.Decode(),
+        >>>                   py_vision.RandomHorizontalFlip(0.5),
+        >>>                   py_vision.ToTensor(),
+        >>>                   py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
+        >>>                   py_vision.RandomErasing()]
+        >>>
+        >>> # apply the transform to the dataset through dataset.map()
+        >>> dataset = dataset.map(operations=transform_list, input_columns="image")
+        >>>
+        >>> # Certain C++ and Python ops can be combined, but not all of them
+        >>> # An example of combined operations
+        >>> import mindspore.dataset as ds
+        >>> import mindspore.dataset.transforms.c_transforms as c_transforms
+        >>> import mindspore.dataset.vision.c_transforms as c_vision
+        >>> arr = [1, 0]
+        >>> data = ds.NumpySlicesDataset(arr, column_names=["cols"], shuffle=False)
+        >>> op_list = [py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)]
+        >>> data = data.map(operations=op_list, input_columns=["cols"])
+        >>>
+        >>> # Here is an example of mixing vision ops
+        >>> op_list = [c_vision.Decode(),
+        >>>            c_vision.Resize((224, 244)),
+        >>>            py_vision.ToPIL(),
+        >>>            np.array,  # need to convert PIL image to a NumPy array to pass it to a C++ operation
+        >>>            c_vision.Resize((24, 24))]
+        >>> data_dir = "/path/to/imagefolder_directory"
+        >>> data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
+        >>> input_columns = ["image"]
+        >>> data1 = data1.map(operations=op_list, input_columns=input_columns)
     """

     @check_compose_list
@@ -94,14 +126,14 @@ class Compose:
         self.transforms = transforms

     @check_compose_call
-    def __call__(self, img):
+    def __call__(self, *args):
         """
         Call method.

         Returns:
-            lambda function, Lambda function that takes in an img to apply transformations on.
+            lambda function, Lambda function that takes in args to apply transformations on.
         """
-        return util.compose(img, self.transforms)
+        return util.compose(self.transforms, *args)


 class RandomApply:
diff --git a/mindspore/dataset/transforms/py_transforms_util.py b/mindspore/dataset/transforms/py_transforms_util.py
index d44ad4de40..bc331e0467 100644
--- a/mindspore/dataset/transforms/py_transforms_util.py
+++ b/mindspore/dataset/transforms/py_transforms_util.py
@@ -21,7 +21,17 @@ import numpy as np
 from ..core.py_util_helpers import is_numpy


-def compose(img, transforms):
+def all_numpy(args):
+    """Check that the input (or every element of a tuple input, for multi-input lambdas) is a NumPy ndarray."""
+    if isinstance(args, tuple):
+        for value in args:
+            if not is_numpy(value):
+                return False
+        return True
+    return is_numpy(args)
+
+
+def compose(transforms, *args):
     """
     Compose a list of transforms and apply on the image.

@@ -32,13 +42,15 @@
     Returns:
         img (numpy.ndarray), An augmented image in Numpy ndarray.
     """
-    if is_numpy(img):
+    if all_numpy(args):
         for transform in transforms:
-            img = transform(img)
-        if is_numpy(img):
-            return img
-        raise TypeError('img should be Numpy ndarray. Got {}. Append ToTensor() to transforms'.format(type(img)))
-    raise TypeError('img should be Numpy ndarray. Got {}.'.format(type(img)))
+            args = transform(*args)
+            args = (args,) if not isinstance(args, tuple) else args
+
+        if all_numpy(args):
+            return args
+        raise TypeError('args should be Numpy ndarray. Got {}. Append ToTensor() to transforms'.format(type(args)))
+    raise TypeError('args should be Numpy ndarray. Got {}.'.format(type(args)))


 def one_hot_encoding(label, num_classes, epsilon):
diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py
index d317caf410..b31956d57b 100644
--- a/mindspore/dataset/transforms/validators.py
+++ b/mindspore/dataset/transforms/validators.py
@@ -213,6 +213,9 @@ def check_compose_list(method):
         type_check(transforms, (list,), transforms)
         if not transforms:
             raise ValueError("transforms list is empty.")
+        for i, transform in enumerate(transforms):
+            if not callable(transform):
+                raise ValueError("transforms[{}] is not callable.".format(i))
         return method(self, *args, **kwargs)

     return new_method
@@ -225,11 +228,10 @@ def check_compose_call(method):
     def new_method(self, *args, **kwargs):
         sig = inspect.signature(method)
         ba = sig.bind_partial(method, *args, **kwargs)
-        img = ba.arguments.get("img")
+        img = ba.arguments.get("args")
         if img is None:
             raise TypeError(
                 "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])()).")
-
         return method(self, *args, **kwargs)

     return new_method
@@ -243,6 +245,10 @@ def check_random_apply(method):

         [transforms, prob], _ = parse_user_args(method, *args, **kwargs)
         type_check(transforms, (list,), "transforms")
+        for i, transform in enumerate(transforms):
+            if not callable(transform):
+                raise ValueError("transforms[{}] is not callable.".format(i))
+
         if prob is not None:
             type_check(prob, (float, int,), "prob")
             check_value(prob, [0., 1.], "prob")
@@ -260,7 +266,9 @@ def check_transforms_list(method):

         [transforms], _ = parse_user_args(method, *args, **kwargs)
         type_check(transforms, (list,), "transforms")
-
+        for i, transform in enumerate(transforms):
+            if not callable(transform):
+                raise ValueError("transforms[{}] is not callable.".format(i))
         return method(self, *args, **kwargs)

     return new_method
diff --git a/mindspore/dataset/vision/py_transforms.py b/mindspore/dataset/vision/py_transforms.py
index fac75a2acc..9e752e92d4 100644
--- a/mindspore/dataset/vision/py_transforms.py
+++ b/mindspore/dataset/vision/py_transforms.py
@@ -623,9 +623,9 @@ class FiveCrop:
         >>> from mindspore.dataset.transforms.py_transforms import Compose
         >>>
         >>> Compose([py_vision.Decode(),
-        >>>          py_vision.FiveCrop(size),
+        >>>          py_vision.FiveCrop(size=200),
         >>>          # 4D stack of 5 images
-        >>>          lambda images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
+        >>>          lambda *images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
     """

     @check_crop
@@ -663,9 +663,9 @@ class TenCrop:
         >>> from mindspore.dataset.transforms.py_transforms import Compose
         >>>
         >>> Compose([py_vision.Decode(),
-        >>>          py_vision.TenCrop(size),
+        >>>          py_vision.TenCrop(size=200),
         >>>          # 4D stack of 10 images
-        >>>          lambda images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
+        >>>          lambda *images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
     """

     @check_ten_crop
diff --git a/tests/ut/data/dataset/golden/compose_c_py_1.npz b/tests/ut/data/dataset/golden/compose_c_py_1.npz
new file mode 100644
index 0000000000..2749e7683d
Binary files /dev/null and b/tests/ut/data/dataset/golden/compose_c_py_1.npz differ
diff --git a/tests/ut/data/dataset/golden/compose_c_py_2.npz b/tests/ut/data/dataset/golden/compose_c_py_2.npz
new file mode 100644
index 0000000000..eeafafc70b
Binary files /dev/null and b/tests/ut/data/dataset/golden/compose_c_py_2.npz differ
diff --git a/tests/ut/data/dataset/golden/compose_c_py_3.npz b/tests/ut/data/dataset/golden/compose_c_py_3.npz
new file mode 100644
index 0000000000..95299850c6
Binary files /dev/null and b/tests/ut/data/dataset/golden/compose_c_py_3.npz differ
diff --git a/tests/ut/python/dataset/test_compose.py b/tests/ut/python/dataset/test_compose.py
index 7cb1e67521..688a9e6081 100644
--- a/tests/ut/python/dataset/test_compose.py
+++ b/tests/ut/python/dataset/test_compose.py
@@ -13,11 +13,19 @@
 # limitations under the License.
 # ==============================================================================
+import numpy as np
 import pytest

 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
-import mindspore.dataset.transforms.c_transforms as ops
-import mindspore.dataset.transforms.py_transforms as py_ops
+import mindspore.dataset.transforms.c_transforms as c_transforms
+import mindspore.dataset.transforms.py_transforms as py_transforms
+
+import mindspore.dataset.vision.c_transforms as c_vision
+import mindspore.dataset.vision.py_transforms as py_vision
+
+from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers
+
+GENERATE_GOLDEN = False


 def test_compose():
@@ -38,35 +46,294 @@ def test_compose():
         return str(e)

     # Test simple compose with only 1 op, this would generate a warning
-    assert test_config([[1, 0], [3, 4]], ops.Compose([ops.Fill(2)])) == [[2, 2], [2, 2]]
+    assert test_config([[1, 0], [3, 4]], c_transforms.Compose([c_transforms.Fill(2)])) == [[2, 2], [2, 2]]
+
     # Test 1 column -> 2 columns -> 1 -> 2 -> 1
     assert test_config([[1, 0]],
-                       ops.Compose([ops.Duplicate(), ops.Concatenate(), ops.Duplicate(), ops.Concatenate()])) \
+                       c_transforms.Compose(
+                           [c_transforms.Duplicate(), c_transforms.Concatenate(), c_transforms.Duplicate(),
+                            c_transforms.Concatenate()])) \
         == [[1, 0] * 4]
-    # Test one Python transform followed by a C transform. Type after OneHot is a float (mixed use-case)
-    assert test_config([1, 0], ops.Compose([py_ops.OneHotOp(2), ops.TypeCast(mstype.int32)])) == [[[0, 1]], [[1, 0]]]
+
+    # Test one Python transform followed by a C++ transform. Type after OneHot is a float (mixed use-case)
+    assert test_config([1, 0],
+                       c_transforms.Compose([py_transforms.OneHotOp(2), c_transforms.TypeCast(mstype.int32)])) \
+        == [[[0, 1]], [[1, 0]]]
+
     # Test exceptions.
     with pytest.raises(TypeError) as error_info:
-        ops.Compose([1, ops.TypeCast(mstype.int32)])
+        c_transforms.Compose([1, c_transforms.TypeCast(mstype.int32)])
     assert "op_list[0] is not a c_transform op (TensorOp) nor a callable pyfunc." in str(error_info.value)
+
     # Test empty op list
     with pytest.raises(ValueError) as error_info:
-        test_config([1, 0], ops.Compose([]))
+        test_config([1, 0], c_transforms.Compose([]))
     assert "op_list can not be empty." in str(error_info.value)

     # Test Python compose op
-    assert test_config([1, 0], py_ops.Compose([py_ops.OneHotOp(2)])) == [[[0, 1]], [[1, 0]]]
-    assert test_config([1, 0], py_ops.Compose([py_ops.OneHotOp(2), (lambda x: x + x)])) == [[[0, 2]], [[2, 0]]]
+    assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2)])) == [[[0, 1]], [[1, 0]]]
+    assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2), (lambda x: x + x)])) == [[[0, 2]],
+                                                                                                          [[2, 0]]]
+
     # Test nested Python compose op
     assert test_config([1, 0],
-                       py_ops.Compose([py_ops.Compose([py_ops.OneHotOp(2)]), (lambda x: x + x)])) \
+                       py_transforms.Compose([py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)])) \
         == [[[0, 2]], [[2, 0]]]

+    # Test passing a list of Python ops without Compose wrapper
+    assert test_config([1, 0],
+                       [py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)]) \
+        == [[[0, 2]], [[2, 0]]]
+    assert test_config([1, 0], [py_transforms.OneHotOp(2), (lambda x: x + x)]) == [[[0, 2]], [[2, 0]]]
+
+    # Test a non callable function
+    with pytest.raises(ValueError) as error_info:
+        py_transforms.Compose([1])
+    assert "transforms[0] is not callable." in str(error_info.value)
+
+    # Test empty Python op list
+    with pytest.raises(ValueError) as error_info:
+        test_config([1, 0], py_transforms.Compose([]))
+    assert "transforms list is empty." in str(error_info.value)
+
+    # Pass in extra brackets
     with pytest.raises(TypeError) as error_info:
-        py_ops.Compose([(lambda x: x + x)])()
+        py_transforms.Compose([(lambda x: x + x)])()
     assert "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])())." in str(
         error_info.value)


+def test_lambdas():
+    """
+    Test Multi Column Python Compose Op
+    """
+    ds.config.set_seed(0)
+
+    def test_config(arr, input_columns, output_cols, op_list):
+        data = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
+        data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
+                        column_order=output_cols)
+        res = []
+        for i in data.create_dict_iterator(output_numpy=True):
+            for col_name in output_cols:
+                res.append(i[col_name].tolist())
+        return res
+
+    arr = ([[1]], [[3]])
+
+    assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([(lambda x, y: x)])) == [[1]]
+    assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([lambda x, y: x, lambda x: x])) == [[1]]
+    assert test_config(arr, ["col0", "col1"], ["a", "b"],
+                       py_transforms.Compose([lambda x, y: x, lambda x: (x, x * 2)])) == \
+        [[1], [2]]
+    assert test_config(arr, ["col0", "col1"], ["a", "b"],
+                       [lambda x, y: (x, x + y), lambda x, y: (x, y * 2)]) == [[1], [8]]
+
+
+def test_c_py_compose_transforms_module():
+    """
+    Test combining Python and C++ transforms
+    """
+    ds.config.set_seed(0)
+
+    def test_config(arr, input_columns, output_cols, op_list):
+        data = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
+        data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
+                        column_order=output_cols)
+        res = []
+        for i in data.create_dict_iterator(output_numpy=True):
+            for col_name in output_cols:
+                res.append(i[col_name].tolist())
+        return res
+
+    arr = [1, 0]
+    assert test_config(arr, ["cols"], ["cols"],
+                       [py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)]) == \
+        [[[False, True]],
+         [[True, False]]]
+    assert test_config(arr, ["cols"], ["cols"],
+                       [py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1)]) \
+        == [[[1, 1]], [[1, 1]]]
+    assert test_config(arr, ["cols"], ["cols"],
+                       [py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1), (lambda x: x + x)]) \
+        == [[[2, 2]], [[2, 2]]]
+    assert test_config([[1, 3]], ["cols"], ["cols"],
+                       [c_transforms.PadEnd([3], -1), (lambda x: x + x)]) \
+        == [[2, 6, -2]]
+
+    arr = ([[1]], [[3]])
+    assert test_config(arr, ["col0", "col1"], ["a"], [(lambda x, y: x + y), c_transforms.PadEnd([2], -1)]) == [[4, -1]]
+
+
+def test_c_py_compose_vision_module(plot=False, run_golden=True):
+    """
+    Test combining Python and C++ vision transforms
+    """
+    original_seed = config_get_set_seed(10)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+
+    def test_config(plot, file_name, op_list):
+        data_dir = "../data/dataset/testImageNetData/train/"
+        data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
+        data1 = data1.map(operations=op_list, input_columns=["image"])
+        data2 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
+        data2 = data2.map(operations=c_vision.Decode(), input_columns=["image"])
+        original_images = []
+        transformed_images = []
+
+        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+            transformed_images.append(item["image"])
+        for item in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
+            original_images.append(item["image"])
+
+        if run_golden:
+            # Compare with expected md5 from images
+            save_and_check_md5(data1, file_name, generate_golden=GENERATE_GOLDEN)
+
+        if plot:
+            visualize_list(original_images, transformed_images)
+
+    test_config(op_list=[c_vision.Decode(),
+                         py_vision.ToPIL(),
+                         py_vision.Resize((224, 224)),
+                         np.array],
+                plot=plot, file_name="compose_c_py_1.npz")
+
+    test_config(op_list=[c_vision.Decode(),
+                         c_vision.Resize((224, 244)),
+                         py_vision.ToPIL(),
+                         np.array,
+                         c_vision.Resize((24, 24))],
+                plot=plot, file_name="compose_c_py_2.npz")
+
+    test_config(op_list=[py_vision.Decode(),
+                         py_vision.Resize((224, 224)),
+                         np.array,
+                         c_vision.RandomColor()],
+                plot=plot, file_name="compose_c_py_3.npz")
+
+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+
+def test_py_transforms_with_c_vision():
+    """
+    These examples will fail, as py_transforms.Random(Apply/Choice/Order) expect callable functions
+    """
+
+    ds.config.set_seed(0)
+
+    def test_config(op_list):
+        data_dir = "../data/dataset/testImageNetData/train/"
+        data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
+        data = data.map(operations=op_list)
+        res = []
+        for i in data.create_dict_iterator(output_numpy=True):
+            for col_name in ["image", "label"]:
+                res.append(i[col_name].tolist())
+        return res
+
+    with pytest.raises(ValueError) as error_info:
+        test_config(py_transforms.RandomApply([c_vision.Resize(200)]))
+    assert "transforms[0] is not callable." in str(error_info.value)
+
+    with pytest.raises(ValueError) as error_info:
+        test_config(py_transforms.RandomChoice([c_vision.Resize(200)]))
+    assert "transforms[0] is not callable." in str(error_info.value)
+
+    with pytest.raises(ValueError) as error_info:
+        test_config(py_transforms.RandomOrder([np.array, c_vision.Resize(200)]))
+    assert "transforms[1] is not callable." in str(error_info.value)
+
+    with pytest.raises(RuntimeError) as error_info:
+        test_config([py_transforms.OneHotOp(20, 0.1)])
+    assert "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()" in str(
+        error_info.value)
+
+
+def test_py_vision_with_c_transforms():
+    """
+    Test combining Python vision operations with C++ transforms operations
+    """
+
+    ds.config.set_seed(0)
+
+    def test_config(op_list):
+        data_dir = "../data/dataset/testImageNetData/train/"
+        data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
+        data1 = data1.map(operations=op_list, input_columns=["image"])
+        transformed_images = []
+
+        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+            transformed_images.append(item["image"])
+        return transformed_images
+
+    # Test with Mask Op
+    output_arr = test_config([py_vision.Decode(),
+                              py_vision.CenterCrop((2)), np.array,
+                              c_transforms.Mask(c_transforms.Relational.GE, 100)])
+
+    exp_arr = [np.array([[[True, False, False],
+                          [True, False, False]],
+                         [[True, False, False],
+                          [True, False, False]]]),
+               np.array([[[True, False, False],
+                          [True, False, False]],
+                         [[True, False, False],
+                          [True, False, False]]])]
+
+    for exp_a, output in zip(exp_arr, output_arr):
+        np.testing.assert_array_equal(exp_a, output)
+
+    # Test with Fill Op
+    output_arr = test_config([py_vision.Decode(),
+                              py_vision.CenterCrop((4)), np.array,
+                              c_transforms.Fill(10)])
+
+    exp_arr = [np.ones((4, 4, 3)) * 10] * 2
+    for exp_a, output in zip(exp_arr, output_arr):
+        np.testing.assert_array_equal(exp_a, output)
+
+    # Test with Concatenate Op, which will raise an error since ConcatenateOp only supports rank 1 tensors.
+    with pytest.raises(RuntimeError) as error_info:
+        test_config([py_vision.Decode(),
+                     py_vision.CenterCrop((2)), np.array,
+                     c_transforms.Concatenate(0)])
+    assert "Only 1D tensors supported" in str(error_info.value)
+
+
+def test_compose_with_custom_function():
+    """
+    Test Python Compose with custom function
+    """
+
+    def custom_function(x):
+        return (x, x * x)
+
+    # First dataset
+    op_list = [
+        lambda x: x * 3,
+        custom_function,
+        # convert the two-column output back to one column
+        lambda *images: np.stack(images)
+    ]
+
+    data = ds.NumpySlicesDataset([[1, 2]], column_names=["col0"], shuffle=False)
+    data = data.map(input_columns=["col0"], operations=op_list)
+
+    # collect the transformed rows
+    res = []
+    for i in data.create_dict_iterator(output_numpy=True):
+        res.append(i["col0"].tolist())
+    assert res == [[[3, 6], [9, 36]]]
+
+
 if __name__ == "__main__":
     test_compose()
+    test_lambdas()
+    test_c_py_compose_transforms_module()
+    test_c_py_compose_vision_module(plot=True)
+    test_py_transforms_with_c_vision()
+    test_py_vision_with_c_transforms()
+    test_compose_with_custom_function()
diff --git a/tests/ut/python/dataset/test_five_crop.py b/tests/ut/python/dataset/test_five_crop.py
index 156d76c2df..9a9176680f 100644
--- a/tests/ut/python/dataset/test_five_crop.py
+++ b/tests/ut/python/dataset/test_five_crop.py
@@ -48,7 +48,7 @@ def test_five_crop_op(plot=False):
     transforms_2 = [
         vision.Decode(),
         vision.FiveCrop(200),
-        lambda images: np.stack([vision.ToTensor()(image) for image in images])  # 4D stack of 5 images
+        lambda *images: np.stack([vision.ToTensor()(image) for image in images])  # 4D stack of 5 images
     ]
     transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2)
     data2 = data2.map(operations=transform_2, input_columns=["image"])
@@ -91,7 +91,7 @@ def test_five_crop_error_msg():
     with pytest.raises(RuntimeError) as info:
         for _ in data:
             pass
Got " + error_msg = "TypeError: __call__() takes 2 positional arguments but 6 were given" # error msg comes from ToTensor() assert error_msg in str(info.value) @@ -108,7 +108,7 @@ def test_five_crop_md5(): transforms = [ vision.Decode(), vision.FiveCrop(100), - lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images + lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images ] transform = mindspore.dataset.transforms.py_transforms.Compose(transforms) data = data.map(operations=transform, input_columns=["image"]) diff --git a/tests/ut/python/dataset/test_pyfunc.py b/tests/ut/python/dataset/test_pyfunc.py index 5fa177223b..19b19f3179 100644 --- a/tests/ut/python/dataset/test_pyfunc.py +++ b/tests/ut/python/dataset/test_pyfunc.py @@ -250,6 +250,32 @@ def test_case_9(): i = i + 4 +def test_pyfunc_implicit_compose(): + """ + Test Implicit Compose with pyfunc + """ + logger.info("Test n-m PyFunc : lambda x, y : (x , x + 1, x + y)") + + col = ["col0", "col1"] + + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + + data1 = data1.map(operations=[(lambda x, y: (x, x + y, x + y + 1)), (lambda x, y, z: (x, y, z))], input_columns=col, + output_columns=["out0", "out1", "out2"], column_order=["out0", "out1", "out2"]) + + i = 0 + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary + # In this test, the dataset is 2x2 sequential tensors + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(item["out0"], golden) + golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) + np.testing.assert_array_equal(item["out1"], golden) + golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]]) + np.testing.assert_array_equal(item["out2"], golden) + i = i + 4 + + def test_pyfunc_execption(): logger.info("Test PyFunc Execption Throw: lambda x : raise Execption()") @@ -293,5 +319,6 @@ if __name__ == "__main__": test_case_7() test_case_8() test_case_9() + test_pyfunc_implicit_compose() test_pyfunc_execption() skip_test_pyfunc_execption_multiprocess() diff --git a/tests/ut/python/dataset/test_ten_crop.py b/tests/ut/python/dataset/test_ten_crop.py index 0b92b78d9c..809303183f 100644 --- a/tests/ut/python/dataset/test_ten_crop.py +++ b/tests/ut/python/dataset/test_ten_crop.py @@ -46,7 +46,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False): transforms_2 = [ vision.Decode(), vision.TenCrop(crop_size, use_vertical_flip=vertical_flip), - lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images + lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images ] transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2) data2 = data2.map(operations=transform_2, input_columns=["image"]) @@ -109,7 +109,7 @@ def test_ten_crop_md5(): transforms_2 = [ vision.Decode(), vision.TenCrop((200, 100), use_vertical_flip=True), - lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images + lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images ] transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2) data2 = data2.map(operations=transform_2, input_columns=["image"]) @@ -176,7 +176,7 @@ def test_ten_crop_wrong_img_error_msg(): with pytest.raises(RuntimeError) as info: 
         data.create_tuple_iterator(num_epochs=1).get_next()
-    error_msg = "TypeError: img should be PIL image or NumPy array. Got "
+    error_msg = "TypeError: __call__() takes 2 positional arguments but 11 were given"  # error msg comes from ToTensor()

     assert error_msg in str(info.value)
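
A minimal usage sketch of the behavior this patch enables (not part of the diff): with the MapDataset change above, dataset.map() accepts a plain list that mixes C++ ops and Python callables, and adjacent Python callables are wrapped in a py_transforms.Compose implicitly. The ops and input data mirror test_c_py_compose_transforms_module above; the printed values are only indicative.

# Illustrative sketch: mixing a Python op and a C++ op in a single map() call.
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c_transforms
import mindspore.dataset.transforms.py_transforms as py_transforms

data = ds.NumpySlicesDataset([1, 0], column_names=["cols"], shuffle=False)
# OneHotOp is a Python transform, Mask is a C++ transform; MapDataset wraps the
# adjacent Python callables in a py_transforms.Compose before handing them to the C++ engine.
data = data.map(operations=[py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)],
                input_columns=["cols"])
for row in data.create_dict_iterator(output_numpy=True):
    print(row["cols"])  # expected per the test above: [[False, True]], then [[True, False]]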