dataset fixes: Update OneHot API docs; fixup Python UTs

pull/5385/head
Cathy Wong 5 years ago
parent 39e2791149
commit 7f6782be2a

@ -36,7 +36,7 @@ from .. import callback
def check_imagefolderdatasetv2(method): def check_imagefolderdatasetv2(method):
"""A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2).""" """A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDatasetV2)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -62,7 +62,7 @@ def check_imagefolderdatasetv2(method):
def check_mnist_cifar_dataset(method): def check_mnist_cifar_dataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset).""" """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -85,7 +85,7 @@ def check_mnist_cifar_dataset(method):
def check_manifestdataset(method): def check_manifestdataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -112,7 +112,7 @@ def check_manifestdataset(method):
def check_tfrecorddataset(method): def check_tfrecorddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -138,7 +138,7 @@ def check_tfrecorddataset(method):
def check_vocdataset(method): def check_vocdataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(VOCDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -179,7 +179,7 @@ def check_vocdataset(method):
def check_cocodataset(method): def check_cocodataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(CocoDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(CocoDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -215,7 +215,7 @@ def check_cocodataset(method):
def check_celebadataset(method): def check_celebadataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(CelebADataset).""" """A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -251,7 +251,7 @@ def check_celebadataset(method):
def check_save(method): def check_save(method):
"""A wrapper that wrap a parameter checker to the save op.""" """A wrapper that wraps a parameter checker around the saved operator."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -271,7 +271,7 @@ def check_save(method):
def check_minddataset(method): def check_minddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -303,7 +303,7 @@ def check_minddataset(method):
def check_generatordataset(method): def check_generatordataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -369,7 +369,7 @@ def check_generatordataset(method):
def check_random_dataset(method): def check_random_dataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(RandomDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(RandomDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -794,7 +794,7 @@ def check_add_column(method):
def check_cluedataset(method): def check_cluedataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -824,7 +824,7 @@ def check_cluedataset(method):
def check_csvdataset(method): def check_csvdataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(CSVDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(CSVDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -871,7 +871,7 @@ def check_csvdataset(method):
def check_textfiledataset(method): def check_textfiledataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -964,7 +964,7 @@ def check_gnn_graphdata(method):
def check_gnn_get_all_nodes(method): def check_gnn_get_all_nodes(method):
"""A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function.""" """A wrapper that wraps a parameter checker around the GNN `get_all_nodes` function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -977,7 +977,7 @@ def check_gnn_get_all_nodes(method):
def check_gnn_get_all_edges(method): def check_gnn_get_all_edges(method):
"""A wrapper that wraps a parameter checker to the GNN `get_all_edges` function.""" """A wrapper that wraps a parameter checker around the GNN `get_all_edges` function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -990,7 +990,7 @@ def check_gnn_get_all_edges(method):
def check_gnn_get_nodes_from_edges(method): def check_gnn_get_nodes_from_edges(method):
"""A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function.""" """A wrapper that wraps a parameter checker around the GNN `get_nodes_from_edges` function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -1003,7 +1003,7 @@ def check_gnn_get_nodes_from_edges(method):
def check_gnn_get_all_neighbors(method): def check_gnn_get_all_neighbors(method):
"""A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function.""" """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -1018,7 +1018,7 @@ def check_gnn_get_all_neighbors(method):
def check_gnn_get_sampled_neighbors(method): def check_gnn_get_sampled_neighbors(method):
"""A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function.""" """A wrapper that wraps a parameter checker around the GNN `get_sampled_neighbors` function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -1046,7 +1046,7 @@ def check_gnn_get_sampled_neighbors(method):
def check_gnn_get_neg_sampled_neighbors(method): def check_gnn_get_neg_sampled_neighbors(method):
"""A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" """A wrapper that wraps a parameter checker around the GNN `get_neg_sampled_neighbors` function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -1062,7 +1062,7 @@ def check_gnn_get_neg_sampled_neighbors(method):
def check_gnn_random_walk(method): def check_gnn_random_walk(method):
"""A wrapper that wraps a parameter checker to the GNN `random_walk` function.""" """A wrapper that wraps a parameter checker around the GNN `random_walk` function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -1110,7 +1110,7 @@ def check_aligned_list(param, param_name, member_type):
def check_gnn_get_node_feature(method): def check_gnn_get_node_feature(method):
"""A wrapper that wraps a parameter checker to the GNN `get_node_feature` function.""" """A wrapper that wraps a parameter checker around the GNN `get_node_feature` function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -1132,7 +1132,7 @@ def check_gnn_get_node_feature(method):
def check_gnn_get_edge_feature(method): def check_gnn_get_edge_feature(method):
"""A wrapper that wrap a parameter checker to the GNN `get_edge_feature` function.""" """A wrapper that wraps a parameter checker around the GNN `get_edge_feature` function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -1154,7 +1154,7 @@ def check_gnn_get_edge_feature(method):
def check_numpyslicesdataset(method): def check_numpyslicesdataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(NumpySlicesDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -1195,17 +1195,17 @@ def check_numpyslicesdataset(method):
def check_paddeddataset(method): def check_paddeddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(PaddedDataset).""" """A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs) _, param_dict = parse_user_args(method, *args, **kwargs)
paddedSamples = param_dict.get("padded_samples") padded_samples = param_dict.get("padded_samples")
if not paddedSamples: if not padded_samples:
raise ValueError("Argument padded_samples cannot be empty") raise ValueError("Argument padded_samples cannot be empty")
type_check(paddedSamples, (list,), "padded_samples") type_check(padded_samples, (list,), "padded_samples")
type_check(paddedSamples[0], (dict,), "padded_element") type_check(padded_samples[0], (dict,), "padded_element")
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method

@ -328,7 +328,7 @@ def check_from_dataset(method):
return new_method return new_method
def check_slidingwindow(method): def check_slidingwindow(method):
"""A wrapper that wrap a parameter checker to the original function(sliding window operation).""" """A wrapper that wraps a parameter checker to the original function(sliding window operation)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -496,4 +496,3 @@ def check_save_model(method):
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method

@ -31,8 +31,8 @@ class OneHot(cde.OneHotOp):
Tensor operation to apply one hot encoding. Tensor operation to apply one hot encoding.
Args: Args:
num_classes (int): Number of classes of the label num_classes (int): Number of classes of the label.
it should be bigger than largest label number in dataset. It should be larger than the largest label number in the dataset.
Raises: Raises:
RuntimeError: feature size is bigger than num_classes. RuntimeError: feature size is bigger than num_classes.

@ -27,8 +27,9 @@ class OneHotOp:
Apply one hot encoding transformation to the input label, make label be more smoothing and continuous. Apply one hot encoding transformation to the input label, make label be more smoothing and continuous.
Args: Args:
num_classes (int): Num class of object in dataset, type is int and value over 0. num_classes (int): Number of classes of objects in dataset. Value must be larger than 0.
smoothing_rate (float): The adjustable Hyper parameter decides the label smoothing level , 0.0 means not do it. smoothing_rate (float, optional): Adjustable hyperparameter for label smoothing level.
(Default=0.0 means no smoothing is applied.)
""" """
@check_one_hot_op @check_one_hot_op

@ -152,7 +152,7 @@ def check_erasing_value(value):
def check_crop(method): def check_crop(method):
"""A wrapper that wraps a parameter checker to the original function(crop operation).""" """A wrapper that wraps a parameter checker around the original function(crop operation)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -165,7 +165,7 @@ def check_crop(method):
def check_posterize(method): def check_posterize(method):
""""A wrapper that wraps a parameter checker to the original function(posterize operation).""" """A wrapper that wraps a parameter checker around the original function(posterize operation)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -187,7 +187,7 @@ def check_posterize(method):
def check_resize_interpolation(method): def check_resize_interpolation(method):
"""A wrapper that wraps a parameter checker to the original function(resize interpolation operation).""" """A wrapper that wraps a parameter checker around the original function(resize interpolation operation)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -202,7 +202,7 @@ def check_resize_interpolation(method):
def check_resize(method): def check_resize(method):
"""A wrapper that wraps a parameter checker to the original function(resize operation).""" """A wrapper that wraps a parameter checker around the original function(resize operation)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -235,7 +235,7 @@ def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
def check_random_resize_crop(method): def check_random_resize_crop(method):
"""A wrapper that wraps a parameter checker to the original function(random resize crop operation).""" """A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -250,7 +250,7 @@ def check_random_resize_crop(method):
def check_prob(method): def check_prob(method):
"""A wrapper that wraps a parameter checker(check the probability) to the original function.""" """A wrapper that wraps a parameter checker (to confirm probability) around the original function."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -264,7 +264,7 @@ def check_prob(method):
def check_normalize_c(method): def check_normalize_c(method):
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in C++).""" """A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -277,7 +277,7 @@ def check_normalize_c(method):
def check_normalize_py(method): def check_normalize_py(method):
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in Python).""" """A wrapper that wraps a parameter checker around the original function(normalize operation written in Python)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):

@ -86,8 +86,13 @@ def test_five_crop_error_msg():
transform = vision.ComposeOp(transforms) transform = vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform()) data = data.map(input_columns=["image"], operations=transform())
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError) as info:
data.create_tuple_iterator().__next__() for _ in data:
pass
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"
# error msg comes from ToTensor()
assert error_msg in str(info.value)
def test_five_crop_md5(): def test_five_crop_md5():

@ -149,7 +149,7 @@ def test_random_color_py_md5():
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms = F.ComposeOp([F.Decode(), transforms = F.ComposeOp([F.Decode(),
F.RandomColor((0.1, 1.9)), F.RandomColor((2.0, 2.5)),
F.ToTensor()]) F.ToTensor()])
data = data.map(input_columns="image", operations=transforms()) data = data.map(input_columns="image", operations=transforms())
@ -244,12 +244,12 @@ def test_random_color_c_errors():
if __name__ == "__main__": if __name__ == "__main__":
test_random_color_py() test_random_color_py()
test_random_color_py(plot=True) test_random_color_py(plot=True)
test_random_color_py(degrees=(0.5, 1.5), plot=True) test_random_color_py(degrees=(2.0, 2.5), plot=True) # Test with degree values that show more obvious transformation
test_random_color_py_md5() test_random_color_py_md5()
test_random_color_c() test_random_color_c()
test_random_color_c(plot=True) test_random_color_c(plot=True)
test_random_color_c(degrees=(0.5, 1.5), plot=True, run_golden=False) test_random_color_c(degrees=(2.0, 2.5), plot=True, run_golden=False) # Test with degree values that show more obvious transformation
test_random_color_c(degrees=(0.1, 0.1), plot=True, run_golden=False) test_random_color_c(degrees=(0.1, 0.1), plot=True, run_golden=False)
test_compare_random_color_op(plot=True) test_compare_random_color_op(plot=True)
test_random_color_c_errors() test_random_color_c_errors()

@ -103,7 +103,7 @@ def test_random_sharpness_py_md5():
# define map operations # define map operations
transforms = [ transforms = [
F.Decode(), F.Decode(),
F.RandomSharpness((0.1, 1.9)), F.RandomSharpness((20.0, 25.0)),
F.ToTensor() F.ToTensor()
] ]
transform = F.ComposeOp(transforms) transform = F.ComposeOp(transforms)
@ -193,7 +193,7 @@ def test_random_sharpness_c_md5():
# define map operations # define map operations
transforms = [ transforms = [
C.Decode(), C.Decode(),
C.RandomSharpness((0.1, 1.9)) C.RandomSharpness((10.0, 15.0))
] ]
# Generate dataset # Generate dataset
@ -337,14 +337,16 @@ def test_random_sharpness_invalid_params():
if __name__ == "__main__": if __name__ == "__main__":
test_random_sharpness_py(plot=True) test_random_sharpness_py(plot=True)
test_random_sharpness_py(None, plot=True) # test with default values test_random_sharpness_py(None, plot=True) # Test with default values
test_random_sharpness_py(degrees=(20.0, 25.0), plot=True) # Test with degree values that show more obvious transformation
test_random_sharpness_py_md5() test_random_sharpness_py_md5()
test_random_sharpness_c(plot=True) test_random_sharpness_c(plot=True)
test_random_sharpness_c(None, plot=True) # test with default values test_random_sharpness_c(None, plot=True) # test with default values
test_random_sharpness_c(degrees=[10, 15], plot=True) # Test with degrees values that show more obvious transformation
test_random_sharpness_c_md5() test_random_sharpness_c_md5()
test_random_sharpness_c_py(degrees=[1.5, 1.5], plot=True) test_random_sharpness_c_py(degrees=[1.5, 1.5], plot=True)
test_random_sharpness_c_py(degrees=[1, 1], plot=True) test_random_sharpness_c_py(degrees=[1, 1], plot=True)
test_random_sharpness_c_py(degrees=[10, 10], plot=True) test_random_sharpness_c_py(degrees=[10, 10], plot=True)
test_random_sharpness_one_channel_c(degrees=[1.7, 1.7], plot=True) test_random_sharpness_one_channel_c(degrees=[1.7, 1.7], plot=True)
test_random_sharpness_one_channel_c(degrees=None, plot=True) # test with default values test_random_sharpness_one_channel_c(degrees=None, plot=True) # Test with default values
test_random_sharpness_invalid_params() test_random_sharpness_invalid_params()

@ -303,7 +303,7 @@ def test_repeat_count0():
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1.repeat(0) data1.repeat(0)
assert "count" in str(info) assert "count" in str(info.value)
def test_repeat_countneg2(): def test_repeat_countneg2():
""" """
@ -313,7 +313,7 @@ def test_repeat_countneg2():
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1.repeat(-2) data1.repeat(-2)
assert "count" in str(info) assert "count" in str(info.value)
if __name__ == "__main__": if __name__ == "__main__":
test_tf_repeat_01() test_tf_repeat_01()

Loading…
Cancel
Save