embed python compose op

Leaf ops do self-reset.

embed python compose op

add implicit compose op

more tests
pull/6301/head
nhussain 4 years ago
parent 12f3665167
commit fda9462682

@ -37,6 +37,9 @@ from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp
from mindspore._c_expression import typing
from mindspore import log as logger
import mindspore.dataset.transforms.py_transforms as py_transforms
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
@ -406,7 +409,7 @@ class Dataset:
return dataset
@check_map
def map(self, operations=None, input_columns=None, output_columns=None, column_order=None,
def map(self, operations, input_columns=None, output_columns=None, column_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
"""
Apply each operation in operations to this dataset.
@ -427,7 +430,7 @@ class Dataset:
Args:
operations (Union[list[TensorOp], list[functions]]): List of operations to be
applied on the dataset. Operations are applied in the order they appear in this list.
input_columns (list[str]): List of the names of the columns that will be passed to
input_columns (list[str], optional): List of the names of the columns that will be passed to
the first operation as input. The size of this list must match the number of
input columns expected by the first operator. (default=None, the first
operation will be passed however many columns that is required, starting from
@ -2021,8 +2024,25 @@ class MapDataset(DatasetOp):
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
super().__init__(num_parallel_workers)
self.children.append(input_dataset)
if operations is not None and not isinstance(operations, list):
operations = [operations]
if operations is not None:
if not isinstance(operations, list):
operations = [operations]
elif isinstance(operations, list) and len(operations) > 1:
# wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations
new_ops, start_ind, end_ind = [], 0, 0
for i, op in enumerate(operations):
if not callable(op):
# reset counts
if start_ind != end_ind:
new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
new_ops.append(op)
start_ind, end_ind = i + 1, i + 1
else:
end_ind += 1
# do additional check in case the last operation is a Python operation
if start_ind != end_ind:
new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
operations = new_ops
self.operations = operations
if input_columns is not None and not isinstance(input_columns, list):
input_columns = [input_columns]

@ -87,6 +87,38 @@ class Compose:
>>> py_vision.RandomErasing()])
>>> # apply the transform to the dataset through dataset.map()
>>> dataset = dataset.map(operations=transform, input_columns="image")
>>>
>>> # Compose can also be invoked implicitly, by just passing in a list of ops
>>> # the above example then becomes:
>>> transform_list = [py_vision.Decode(),
>>> py_vision.RandomHorizontalFlip(0.5),
>>> py_vision.ToTensor(),
>>> py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
>>> py_vision.RandomErasing()]
>>>
>>> # apply the transform to the dataset through dataset.map()
>>> dataset = dataset.map(operations=transform_list, input_columns="image")
>>>
>>> # Certain C++ and Python ops can be combined, but not all of them
>>> # An example of combined operations
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.transforms.c_transforms as c_transforms
>>> import mindspore.dataset.vision.c_transforms as c_vision
>>>
>>> data = ds.NumpySlicesDataset(arr, column_names=["cols"], shuffle=False)
>>> transformed_list = [py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)]
>>> data = data.map(operations=transformed_list, input_columns=["cols"])
>>>
>>> # Here is an example of mixing vision ops
>>> data_dir = "/path/to/imagefolder_directory"
>>> data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
>>> input_columns = ["column_names"]
>>> op_list=[c_vision.Decode(),
>>> c_vision.Resize((224, 244)),
>>> py_vision.ToPIL(),
>>> np.array, # need to convert PIL image to a NumPy array to pass it to C++ operation
>>> c_vision.Resize((24, 24))]
>>> data1 = data1.map(operations=op_list, input_columns=input_columns)
"""
@check_compose_list
@ -94,14 +126,14 @@ class Compose:
self.transforms = transforms
@check_compose_call
def __call__(self, img):
def __call__(self, *args):
"""
Call method.
Returns:
lambda function, Lambda function that takes in an img to apply transformations on.
lambda function, Lambda function that takes in args to apply transformations on.
"""
return util.compose(img, self.transforms)
return util.compose(self.transforms, *args)
class RandomApply:

@ -21,7 +21,17 @@ import numpy as np
from ..core.py_util_helpers import is_numpy
def compose(img, transforms):
def all_numpy(args):
    """
    Check whether the input(s) pass is_numpy; used for multi-input lambdas.

    Args:
        args: Either a single value or a tuple of values to check.

    Returns:
        bool, for a tuple: True only if every element passes is_numpy (an empty
        tuple yields True). For any other value: the result of is_numpy(args).
    """
    # A non-tuple is treated as one single input and checked directly.
    if not isinstance(args, tuple):
        return is_numpy(args)
    # A tuple carries several inputs; each one has to pass the check.
    return all(is_numpy(item) for item in args)
def compose(transforms, *args):
"""
Compose a list of transforms and apply on the image.
@ -32,13 +42,15 @@ def compose(img, transforms):
Returns:
img (numpy.ndarray), An augmented image in Numpy ndarray.
"""
if is_numpy(img):
if all_numpy(args):
for transform in transforms:
img = transform(img)
if is_numpy(img):
return img
raise TypeError('img should be Numpy ndarray. Got {}. Append ToTensor() to transforms'.format(type(img)))
raise TypeError('img should be Numpy ndarray. Got {}.'.format(type(img)))
args = transform(*args)
args = (args,) if not isinstance(args, tuple) else args
if all_numpy(args):
return args
raise TypeError('args should be Numpy ndarray. Got {}. Append ToTensor() to transforms'.format(type(args)))
raise TypeError('args should be Numpy ndarray. Got {}.'.format(type(args)))
def one_hot_encoding(label, num_classes, epsilon):

@ -213,6 +213,9 @@ def check_compose_list(method):
type_check(transforms, (list,), transforms)
if not transforms:
raise ValueError("transforms list is empty.")
for i, transfrom in enumerate(transforms):
if not callable(transfrom):
raise ValueError("transforms[{}] is not callable.".format(i))
return method(self, *args, **kwargs)
return new_method
@ -225,11 +228,10 @@ def check_compose_call(method):
def new_method(self, *args, **kwargs):
sig = inspect.signature(method)
ba = sig.bind_partial(method, *args, **kwargs)
img = ba.arguments.get("img")
img = ba.arguments.get("args")
if img is None:
raise TypeError(
"Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])()).")
return method(self, *args, **kwargs)
return new_method
@ -243,6 +245,10 @@ def check_random_apply(method):
[transforms, prob], _ = parse_user_args(method, *args, **kwargs)
type_check(transforms, (list,), "transforms")
for i, transfrom in enumerate(transforms):
if not callable(transfrom):
raise ValueError("transforms[{}] is not callable.".format(i))
if prob is not None:
type_check(prob, (float, int,), "prob")
check_value(prob, [0., 1.], "prob")
@ -260,7 +266,9 @@ def check_transforms_list(method):
[transforms], _ = parse_user_args(method, *args, **kwargs)
type_check(transforms, (list,), "transforms")
for i, transfrom in enumerate(transforms):
if not callable(transfrom):
raise ValueError("transforms[{}] is not callable.".format(i))
return method(self, *args, **kwargs)
return new_method

@ -623,9 +623,9 @@ class FiveCrop:
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>>
>>> Compose([py_vision.Decode(),
>>> py_vision.FiveCrop(size),
>>> py_vision.FiveCrop(size=200),
>>> # 4D stack of 5 images
>>> lambda images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
>>> lambda *images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
"""
@check_crop
@ -663,9 +663,9 @@ class TenCrop:
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>>
>>> Compose([py_vision.Decode(),
>>> py_vision.TenCrop(size),
>>> py_vision.TenCrop(size=200),
>>> # 4D stack of 10 images
>>> lambda images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
>>> lambda *images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
"""
@check_ten_crop

File diff suppressed because it is too large Load Diff

@ -48,7 +48,7 @@ def test_five_crop_op(plot=False):
transforms_2 = [
vision.Decode(),
vision.FiveCrop(200),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images
lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images
]
transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2)
data2 = data2.map(operations=transform_2, input_columns=["image"])
@ -91,7 +91,7 @@ def test_five_crop_error_msg():
with pytest.raises(RuntimeError) as info:
for _ in data:
pass
error_msg = "TypeError: img should be PIL image or NumPy array. Got <class 'tuple'>"
error_msg = "TypeError: __call__() takes 2 positional arguments but 6 were given"
# error msg comes from ToTensor()
assert error_msg in str(info.value)
@ -108,7 +108,7 @@ def test_five_crop_md5():
transforms = [
vision.Decode(),
vision.FiveCrop(100),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images
lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images
]
transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
data = data.map(operations=transform, input_columns=["image"])

@ -250,6 +250,32 @@ def test_case_9():
i = i + 4
def test_pyfunc_implicit_compose():
    """
    Test implicit Compose with pyfunc.

    A list of two Python callables is passed to map(): the first turns the two
    input columns into three values, the second passes those three values through
    unchanged. The three output columns are then compared with expected values.
    """
    # Log the lambda that is actually applied below.
    logger.info("Test implicit Compose PyFunc : lambda x, y : (x, x + y, x + y + 1)")

    col = ["col0", "col1"]

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

    data1 = data1.map(operations=[(lambda x, y: (x, x + y, x + y + 1)), (lambda x, y, z: (x, y, z))],
                      input_columns=col,
                      output_columns=["out0", "out1", "out2"],
                      column_order=["out0", "out1", "out2"])

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # In this test, the dataset is 2x2 sequential tensors
        base = np.array([[i, i + 1], [i + 2, i + 3]])
        # out0 is x unchanged
        np.testing.assert_array_equal(item["out0"], base)
        # out1 is x + y; the expected value 2 * x implies both input columns hold the same data
        np.testing.assert_array_equal(item["out1"], base * 2)
        # out2 is x + y + 1
        np.testing.assert_array_equal(item["out2"], base * 2 + 1)
        i = i + 4
def test_pyfunc_execption():
logger.info("Test PyFunc Execption Throw: lambda x : raise Execption()")
@ -293,5 +319,6 @@ if __name__ == "__main__":
test_case_7()
test_case_8()
test_case_9()
test_pyfunc_implicit_compose()
test_pyfunc_execption()
skip_test_pyfunc_execption_multiprocess()

@ -46,7 +46,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False):
transforms_2 = [
vision.Decode(),
vision.TenCrop(crop_size, use_vertical_flip=vertical_flip),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
]
transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2)
data2 = data2.map(operations=transform_2, input_columns=["image"])
@ -109,7 +109,7 @@ def test_ten_crop_md5():
transforms_2 = [
vision.Decode(),
vision.TenCrop((200, 100), use_vertical_flip=True),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
]
transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2)
data2 = data2.map(operations=transform_2, input_columns=["image"])
@ -176,7 +176,7 @@ def test_ten_crop_wrong_img_error_msg():
with pytest.raises(RuntimeError) as info:
data.create_tuple_iterator(num_epochs=1).get_next()
error_msg = "TypeError: img should be PIL image or NumPy array. Got <class 'tuple'>"
error_msg = "TypeError: __call__() takes 2 positional arguments but 11 were given"
# error msg comes from ToTensor()
assert error_msg in str(info.value)

Loading…
Cancel
Save